Test chunk_by_paragraph chunk numbering
This commit is contained in:
parent
84c98f16bb
commit
928e1075c6
3 changed files with 44 additions and 4 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||||
|
|
||||||
|
|
||||||
def test_node_initialization():
|
def test_node_initialization():
|
||||||
|
|
@ -12,11 +12,13 @@ def test_node_initialization():
|
||||||
assert len(node.status) == 2
|
assert len(node.status) == 2
|
||||||
assert np.all(node.status == 1)
|
assert np.all(node.status == 1)
|
||||||
|
|
||||||
|
|
||||||
def test_node_invalid_dimension():
|
def test_node_invalid_dimension():
|
||||||
"""Test that initializing a Node with a non-positive dimension raises an error."""
|
"""Test that initializing a Node with a non-positive dimension raises an error."""
|
||||||
with pytest.raises(ValueError, match="Dimension must be a positive integer"):
|
with pytest.raises(ValueError, match="Dimension must be a positive integer"):
|
||||||
Node("node1", dimension=0)
|
Node("node1", dimension=0)
|
||||||
|
|
||||||
|
|
||||||
def test_add_skeleton_neighbor():
|
def test_add_skeleton_neighbor():
|
||||||
"""Test adding a neighbor to a node."""
|
"""Test adding a neighbor to a node."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -24,6 +26,7 @@ def test_add_skeleton_neighbor():
|
||||||
node1.add_skeleton_neighbor(node2)
|
node1.add_skeleton_neighbor(node2)
|
||||||
assert node2 in node1.skeleton_neighbours
|
assert node2 in node1.skeleton_neighbours
|
||||||
|
|
||||||
|
|
||||||
def test_remove_skeleton_neighbor():
|
def test_remove_skeleton_neighbor():
|
||||||
"""Test removing a neighbor from a node."""
|
"""Test removing a neighbor from a node."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -32,6 +35,7 @@ def test_remove_skeleton_neighbor():
|
||||||
node1.remove_skeleton_neighbor(node2)
|
node1.remove_skeleton_neighbor(node2)
|
||||||
assert node2 not in node1.skeleton_neighbours
|
assert node2 not in node1.skeleton_neighbours
|
||||||
|
|
||||||
|
|
||||||
def test_add_skeleton_edge():
|
def test_add_skeleton_edge():
|
||||||
"""Test adding an edge updates both skeleton_edges and skeleton_neighbours."""
|
"""Test adding an edge updates both skeleton_edges and skeleton_neighbours."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -41,6 +45,7 @@ def test_add_skeleton_edge():
|
||||||
assert edge in node1.skeleton_edges
|
assert edge in node1.skeleton_edges
|
||||||
assert node2 in node1.skeleton_neighbours
|
assert node2 in node1.skeleton_neighbours
|
||||||
|
|
||||||
|
|
||||||
def test_remove_skeleton_edge():
|
def test_remove_skeleton_edge():
|
||||||
"""Test removing an edge updates both skeleton_edges and skeleton_neighbours."""
|
"""Test removing an edge updates both skeleton_edges and skeleton_neighbours."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -51,6 +56,7 @@ def test_remove_skeleton_edge():
|
||||||
assert edge not in node1.skeleton_edges
|
assert edge not in node1.skeleton_edges
|
||||||
assert node2 not in node1.skeleton_neighbours
|
assert node2 not in node1.skeleton_neighbours
|
||||||
|
|
||||||
|
|
||||||
def test_is_node_alive_in_dimension():
|
def test_is_node_alive_in_dimension():
|
||||||
"""Test checking node's alive status in a specific dimension."""
|
"""Test checking node's alive status in a specific dimension."""
|
||||||
node = Node("node1", dimension=2)
|
node = Node("node1", dimension=2)
|
||||||
|
|
@ -58,25 +64,30 @@ def test_is_node_alive_in_dimension():
|
||||||
node.status[1] = 0
|
node.status[1] = 0
|
||||||
assert not node.is_node_alive_in_dimension(1)
|
assert not node.is_node_alive_in_dimension(1)
|
||||||
|
|
||||||
|
|
||||||
def test_node_alive_invalid_dimension():
|
def test_node_alive_invalid_dimension():
|
||||||
"""Test that checking alive status with an invalid dimension raises an error."""
|
"""Test that checking alive status with an invalid dimension raises an error."""
|
||||||
node = Node("node1", dimension=1)
|
node = Node("node1", dimension=1)
|
||||||
with pytest.raises(ValueError, match="Dimension 1 is out of range"):
|
with pytest.raises(ValueError, match="Dimension 1 is out of range"):
|
||||||
node.is_node_alive_in_dimension(1)
|
node.is_node_alive_in_dimension(1)
|
||||||
|
|
||||||
|
|
||||||
def test_node_equality():
|
def test_node_equality():
|
||||||
"""Test equality between nodes."""
|
"""Test equality between nodes."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
node2 = Node("node1")
|
node2 = Node("node1")
|
||||||
assert node1 == node2
|
assert node1 == node2
|
||||||
|
|
||||||
|
|
||||||
def test_node_hash():
|
def test_node_hash():
|
||||||
"""Test hashing for Node."""
|
"""Test hashing for Node."""
|
||||||
node = Node("node1")
|
node = Node("node1")
|
||||||
assert hash(node) == hash("node1")
|
assert hash(node) == hash("node1")
|
||||||
|
|
||||||
|
|
||||||
### Tests for Edge ###
|
### Tests for Edge ###
|
||||||
|
|
||||||
|
|
||||||
def test_edge_initialization():
|
def test_edge_initialization():
|
||||||
"""Test that an Edge is initialized correctly."""
|
"""Test that an Edge is initialized correctly."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -89,6 +100,7 @@ def test_edge_initialization():
|
||||||
assert len(edge.status) == 2
|
assert len(edge.status) == 2
|
||||||
assert np.all(edge.status == 1)
|
assert np.all(edge.status == 1)
|
||||||
|
|
||||||
|
|
||||||
def test_edge_invalid_dimension():
|
def test_edge_invalid_dimension():
|
||||||
"""Test that initializing an Edge with a non-positive dimension raises an error."""
|
"""Test that initializing an Edge with a non-positive dimension raises an error."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -96,6 +108,7 @@ def test_edge_invalid_dimension():
|
||||||
with pytest.raises(ValueError, match="Dimensions must be a positive integer."):
|
with pytest.raises(ValueError, match="Dimensions must be a positive integer."):
|
||||||
Edge(node1, node2, dimension=0)
|
Edge(node1, node2, dimension=0)
|
||||||
|
|
||||||
|
|
||||||
def test_is_edge_alive_in_dimension():
|
def test_is_edge_alive_in_dimension():
|
||||||
"""Test checking edge's alive status in a specific dimension."""
|
"""Test checking edge's alive status in a specific dimension."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -105,6 +118,7 @@ def test_is_edge_alive_in_dimension():
|
||||||
edge.status[1] = 0
|
edge.status[1] = 0
|
||||||
assert not edge.is_edge_alive_in_dimension(1)
|
assert not edge.is_edge_alive_in_dimension(1)
|
||||||
|
|
||||||
|
|
||||||
def test_edge_alive_invalid_dimension():
|
def test_edge_alive_invalid_dimension():
|
||||||
"""Test that checking alive status with an invalid dimension raises an error."""
|
"""Test that checking alive status with an invalid dimension raises an error."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -113,6 +127,7 @@ def test_edge_alive_invalid_dimension():
|
||||||
with pytest.raises(ValueError, match="Dimension 1 is out of range"):
|
with pytest.raises(ValueError, match="Dimension 1 is out of range"):
|
||||||
edge.is_edge_alive_in_dimension(1)
|
edge.is_edge_alive_in_dimension(1)
|
||||||
|
|
||||||
|
|
||||||
def test_edge_equality_directed():
|
def test_edge_equality_directed():
|
||||||
"""Test equality between directed edges."""
|
"""Test equality between directed edges."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -121,6 +136,7 @@ def test_edge_equality_directed():
|
||||||
edge2 = Edge(node1, node2, directed=True)
|
edge2 = Edge(node1, node2, directed=True)
|
||||||
assert edge1 == edge2
|
assert edge1 == edge2
|
||||||
|
|
||||||
|
|
||||||
def test_edge_equality_undirected():
|
def test_edge_equality_undirected():
|
||||||
"""Test equality between undirected edges."""
|
"""Test equality between undirected edges."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -129,6 +145,7 @@ def test_edge_equality_undirected():
|
||||||
edge2 = Edge(node2, node1, directed=False)
|
edge2 = Edge(node2, node1, directed=False)
|
||||||
assert edge1 == edge2
|
assert edge1 == edge2
|
||||||
|
|
||||||
|
|
||||||
def test_edge_hash_directed():
|
def test_edge_hash_directed():
|
||||||
"""Test hashing for directed edges."""
|
"""Test hashing for directed edges."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
|
|
@ -136,9 +153,10 @@ def test_edge_hash_directed():
|
||||||
edge = Edge(node1, node2, directed=True)
|
edge = Edge(node1, node2, directed=True)
|
||||||
assert hash(edge) == hash((node1, node2))
|
assert hash(edge) == hash((node1, node2))
|
||||||
|
|
||||||
|
|
||||||
def test_edge_hash_undirected():
|
def test_edge_hash_undirected():
|
||||||
"""Test hashing for undirected edges."""
|
"""Test hashing for undirected edges."""
|
||||||
node1 = Node("node1")
|
node1 = Node("node1")
|
||||||
node2 = Node("node2")
|
node2 = Node("node2")
|
||||||
edge = Edge(node1, node2, directed=False)
|
edge = Edge(node1, node2, directed=False)
|
||||||
assert hash(edge) == hash(frozenset({node1, node2}))
|
assert hash(edge) == hash(frozenset({node1, node2}))
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -9,6 +9,7 @@ def setup_graph():
|
||||||
"""Fixture to initialize a CogneeGraph instance."""
|
"""Fixture to initialize a CogneeGraph instance."""
|
||||||
return CogneeGraph()
|
return CogneeGraph()
|
||||||
|
|
||||||
|
|
||||||
def test_add_node_success(setup_graph):
|
def test_add_node_success(setup_graph):
|
||||||
"""Test successful addition of a node."""
|
"""Test successful addition of a node."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
@ -16,6 +17,7 @@ def test_add_node_success(setup_graph):
|
||||||
graph.add_node(node)
|
graph.add_node(node)
|
||||||
assert graph.get_node("node1") == node
|
assert graph.get_node("node1") == node
|
||||||
|
|
||||||
|
|
||||||
def test_add_duplicate_node(setup_graph):
|
def test_add_duplicate_node(setup_graph):
|
||||||
"""Test adding a duplicate node raises an exception."""
|
"""Test adding a duplicate node raises an exception."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
@ -24,6 +26,7 @@ def test_add_duplicate_node(setup_graph):
|
||||||
with pytest.raises(ValueError, match="Node with id node1 already exists."):
|
with pytest.raises(ValueError, match="Node with id node1 already exists."):
|
||||||
graph.add_node(node)
|
graph.add_node(node)
|
||||||
|
|
||||||
|
|
||||||
def test_add_edge_success(setup_graph):
|
def test_add_edge_success(setup_graph):
|
||||||
"""Test successful addition of an edge."""
|
"""Test successful addition of an edge."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
@ -37,6 +40,7 @@ def test_add_edge_success(setup_graph):
|
||||||
assert edge in node1.skeleton_edges
|
assert edge in node1.skeleton_edges
|
||||||
assert edge in node2.skeleton_edges
|
assert edge in node2.skeleton_edges
|
||||||
|
|
||||||
|
|
||||||
def test_add_duplicate_edge(setup_graph):
|
def test_add_duplicate_edge(setup_graph):
|
||||||
"""Test adding a duplicate edge raises an exception."""
|
"""Test adding a duplicate edge raises an exception."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
@ -49,6 +53,7 @@ def test_add_duplicate_edge(setup_graph):
|
||||||
with pytest.raises(ValueError, match="Edge .* already exists in the graph."):
|
with pytest.raises(ValueError, match="Edge .* already exists in the graph."):
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
|
|
||||||
def test_get_node_success(setup_graph):
|
def test_get_node_success(setup_graph):
|
||||||
"""Test retrieving an existing node."""
|
"""Test retrieving an existing node."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
@ -56,11 +61,13 @@ def test_get_node_success(setup_graph):
|
||||||
graph.add_node(node)
|
graph.add_node(node)
|
||||||
assert graph.get_node("node1") == node
|
assert graph.get_node("node1") == node
|
||||||
|
|
||||||
|
|
||||||
def test_get_node_nonexistent(setup_graph):
|
def test_get_node_nonexistent(setup_graph):
|
||||||
"""Test retrieving a nonexistent node returns None."""
|
"""Test retrieving a nonexistent node returns None."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
assert graph.get_node("nonexistent") is None
|
assert graph.get_node("nonexistent") is None
|
||||||
|
|
||||||
|
|
||||||
def test_get_edges_success(setup_graph):
|
def test_get_edges_success(setup_graph):
|
||||||
"""Test retrieving edges of a node."""
|
"""Test retrieving edges of a node."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
@ -72,6 +79,7 @@ def test_get_edges_success(setup_graph):
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
assert edge in graph.get_edges("node1")
|
assert edge in graph.get_edges("node1")
|
||||||
|
|
||||||
|
|
||||||
def test_get_edges_nonexistent_node(setup_graph):
|
def test_get_edges_nonexistent_node(setup_graph):
|
||||||
"""Test retrieving edges for a nonexistent node raises an exception."""
|
"""Test retrieving edges for a nonexistent node raises an exception."""
|
||||||
graph = setup_graph
|
graph = setup_graph
|
||||||
|
|
|
||||||
|
|
@ -37,3 +37,17 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
|
||||||
assert np.all(
|
assert np.all(
|
||||||
chunk_lengths <= paragraph_length
|
chunk_lengths <= paragraph_length
|
||||||
), f"{paragraph_length = }: {larger_chunks} are too large"
|
), f"{paragraph_length = }: {larger_chunks} are too large"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_text,paragraph_length,batch_paragraphs",
|
||||||
|
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
|
||||||
|
)
|
||||||
|
def test_chunk_by_paragraph_chunk_numbering(
|
||||||
|
input_text, paragraph_length, batch_paragraphs
|
||||||
|
):
|
||||||
|
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
|
||||||
|
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
||||||
|
assert np.all(
|
||||||
|
chunk_indices == np.arange(len(chunk_indices))
|
||||||
|
), f"{chunk_indices = } are not monotonically increasing"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue