update-tests (#872)
* update-tests * unit test update * update tests * update tests * update kuzu query * update * update query * update args * fix bulk episode add * make handling better
This commit is contained in:
parent
119a43b8e4
commit
da6f3336bb
8 changed files with 84 additions and 36 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import csv
|
||||
import os
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
|
|
@ -57,14 +58,41 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size
|
|||
)
|
||||
|
||||
|
||||
async def neo4j_node_label_migration(driver: GraphDriver):
|
||||
def pop_last_n_group_ids(csv_file: str = 'group_ids.csv', count: int = 10):
|
||||
with open(csv_file) as file:
|
||||
reader = csv.reader(file)
|
||||
group_ids = [row[0] for row in reader]
|
||||
|
||||
total_count = len(group_ids)
|
||||
popped = group_ids[-count:]
|
||||
remaining = group_ids[:-count]
|
||||
|
||||
with open(csv_file, 'w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
for gid in remaining:
|
||||
writer.writerow([gid])
|
||||
|
||||
return popped, total_count
|
||||
|
||||
|
||||
async def get_group_ids(driver: GraphDriver):
|
||||
query = """MATCH (n:Episodic)
|
||||
RETURN DISTINCT n.group_id AS group_id"""
|
||||
|
||||
results, _, _ = await driver.execute_query(query)
|
||||
for result in results:
|
||||
group_id = result['group_id']
|
||||
await neo4j_node_group_labels(driver, group_id)
|
||||
group_ids = [result['group_id'] for result in results]
|
||||
|
||||
with open('group_ids.csv', 'w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
for gid in group_ids:
|
||||
writer.writerow([gid])
|
||||
|
||||
|
||||
async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10):
|
||||
group_ids, total = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
|
||||
while len(group_ids) > 0:
|
||||
await asyncio.gather(*[neo4j_node_group_labels(driver, group_id) for group_id in group_ids])
|
||||
group_ids, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -77,6 +105,7 @@ async def main():
|
|||
user=neo4j_user,
|
||||
password=neo4j_password,
|
||||
)
|
||||
await get_group_ids(driver)
|
||||
await neo4j_node_label_migration(driver)
|
||||
await driver.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -286,18 +286,24 @@ class EpisodicNode(Node):
|
|||
}
|
||||
],
|
||||
)
|
||||
|
||||
episode_args = {
|
||||
'uuid': self.uuid,
|
||||
'name': self.name,
|
||||
'group_id': self.group_id,
|
||||
'source_description': self.source_description,
|
||||
'content': self.content,
|
||||
'entity_edges': self.entity_edges,
|
||||
'created_at': self.created_at,
|
||||
'valid_at': self.valid_at,
|
||||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.provider == GraphProvider.NEO4J:
|
||||
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_episode_node_save_query(driver.provider),
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
group_label='Episodic_' + self.group_id.replace('-', ''),
|
||||
source_description=self.source_description,
|
||||
content=self.content,
|
||||
entity_edges=self.entity_edges,
|
||||
created_at=self.created_at,
|
||||
valid_at=self.valid_at,
|
||||
source=self.source.value,
|
||||
get_episode_node_save_query(driver.provider), **episode_args
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||
|
|
|
|||
|
|
@ -605,9 +605,7 @@ async def node_fulltext_search(
|
|||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
query = (
|
||||
get_nodes_query(
|
||||
index_name, '$query', limit=limit, provider=driver.provider
|
||||
)
|
||||
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
||||
+ yield_query
|
||||
+ filter_query
|
||||
+ """
|
||||
|
|
|
|||
|
|
@ -119,7 +119,8 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
for episode in episodes:
|
||||
episode['source'] = str(episode['source'].value)
|
||||
episode.pop('labels', None)
|
||||
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
|
||||
if driver.provider == GraphProvider.NEO4J:
|
||||
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
|
||||
|
||||
nodes = []
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.19.0pre3"
|
||||
version = "0.19.0pre4"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
|
|
@ -116,9 +116,7 @@ def mock_cross_encoder_client():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_bulk(
|
||||
graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
|
||||
):
|
||||
async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
|
||||
if graph_driver.provider == GraphProvider.FALKORDB:
|
||||
pytest.skip('Skipping as test fails on FalkorDB')
|
||||
|
||||
|
|
@ -143,7 +141,7 @@ async def test_add_bulk(
|
|||
source_description='conversation message',
|
||||
content='Alice likes Bob',
|
||||
valid_at=now,
|
||||
entity_edges=[], # Filled in later
|
||||
entity_edges=[], # Filled in later
|
||||
)
|
||||
episode_node_2 = EpisodicNode(
|
||||
name='test_episode_2',
|
||||
|
|
@ -154,14 +152,14 @@ async def test_add_bulk(
|
|||
source_description='conversation message',
|
||||
content='Bob adores Alice',
|
||||
valid_at=now,
|
||||
entity_edges=[], # Filled in later
|
||||
entity_edges=[], # Filled in later
|
||||
)
|
||||
|
||||
# Create entity nodes
|
||||
entity_node_1 = EntityNode(
|
||||
name='test_entity_1',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Person'],
|
||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
|
||||
created_at=now,
|
||||
summary='test_entity_1 summary',
|
||||
attributes={'age': 30, 'location': 'New York'},
|
||||
|
|
@ -171,7 +169,7 @@ async def test_add_bulk(
|
|||
entity_node_2 = EntityNode(
|
||||
name='test_entity_2',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Person2'],
|
||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person2'],
|
||||
created_at=now,
|
||||
summary='test_entity_2 summary',
|
||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||
|
|
@ -181,7 +179,7 @@ async def test_add_bulk(
|
|||
entity_node_3 = EntityNode(
|
||||
name='test_entity_3',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'City', 'Location'],
|
||||
labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'],
|
||||
created_at=now,
|
||||
summary='test_entity_3 summary',
|
||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||
|
|
@ -191,7 +189,7 @@ async def test_add_bulk(
|
|||
entity_node_4 = EntityNode(
|
||||
name='test_entity_4',
|
||||
group_id=group_id,
|
||||
labels=['Entity'],
|
||||
labels=['Entity', 'Entity_graphiti_test_group'],
|
||||
created_at=now,
|
||||
summary='test_entity_4 summary',
|
||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||
|
|
@ -269,8 +267,22 @@ async def test_add_bulk(
|
|||
mock_embedder,
|
||||
)
|
||||
|
||||
node_ids = [episode_node_1.uuid, episode_node_2.uuid, entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
|
||||
edge_ids = [episodic_edge_1.uuid, episodic_edge_2.uuid, episodic_edge_3.uuid, episodic_edge_4.uuid, entity_edge_1.uuid, entity_edge_2.uuid]
|
||||
node_ids = [
|
||||
episode_node_1.uuid,
|
||||
episode_node_2.uuid,
|
||||
entity_node_1.uuid,
|
||||
entity_node_2.uuid,
|
||||
entity_node_3.uuid,
|
||||
entity_node_4.uuid,
|
||||
]
|
||||
edge_ids = [
|
||||
episodic_edge_1.uuid,
|
||||
episodic_edge_2.uuid,
|
||||
episodic_edge_3.uuid,
|
||||
episodic_edge_4.uuid,
|
||||
entity_edge_1.uuid,
|
||||
entity_edge_2.uuid,
|
||||
]
|
||||
node_count = await get_node_count(graph_driver, node_ids)
|
||||
assert node_count == len(node_ids)
|
||||
edge_count = await get_edge_count(graph_driver, edge_ids)
|
||||
|
|
@ -290,7 +302,6 @@ async def test_add_bulk(
|
|||
retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
|
||||
await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)
|
||||
|
||||
|
||||
retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
|
||||
await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)
|
||||
|
||||
|
|
@ -317,6 +328,7 @@ async def test_add_bulk(
|
|||
retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
|
||||
await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_episode(
|
||||
graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
|
||||
|
|
@ -342,7 +354,7 @@ async def test_remove_episode(
|
|||
source_description='conversation message',
|
||||
content='Alice likes Bob',
|
||||
valid_at=now,
|
||||
entity_edges=[], # Filled in later
|
||||
entity_edges=[], # Filled in later
|
||||
)
|
||||
|
||||
# Create entity nodes
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def sample_entity_node():
|
|||
uuid=str(uuid4()),
|
||||
name='Test Entity',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Person'],
|
||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
|
||||
created_at=created_at,
|
||||
name_embedding=[0.5] * 1024,
|
||||
summary='Entity Summary',
|
||||
|
|
@ -103,7 +103,9 @@ async def test_entity_node(sample_entity_node, graph_driver):
|
|||
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
|
||||
|
||||
# Get node by group ids
|
||||
retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True)
|
||||
retrieved = await EntityNode.get_by_group_ids(
|
||||
graph_driver, [group_id], limit=2, with_embeddings=True
|
||||
)
|
||||
assert len(retrieved) == 1
|
||||
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
|
||||
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.19.0rc3"
|
||||
version = "0.19.0rc4"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue