Merge pull request #378 from topoteretes/COG-748

feat: Add versioning to the data point model
Authored by Vasilije on 2025-01-16 20:04:51 +01:00; committed by GitHub.
commit b61dfd0948
18 changed files with 186 additions and 124 deletions

View file

@@ -62,10 +62,12 @@ class Neo4jAdapter(GraphDBInterface):
     async def add_node(self, node: DataPoint):
         serialized_properties = self.serialize_properties(node.model_dump())
 
-        query = dedent("""MERGE (node {id: $node_id})
+        query = dedent(
+            """MERGE (node {id: $node_id})
                 ON CREATE SET node += $properties, node.updated_at = timestamp()
                 ON MATCH SET node += $properties, node.updated_at = timestamp()
-                RETURN ID(node) AS internal_id, node.id AS nodeId""")
+                RETURN ID(node) AS internal_id, node.id AS nodeId"""
+        )
 
         params = {
             "node_id": str(node.id),
@@ -182,13 +184,15 @@ class Neo4jAdapter(GraphDBInterface):
     ):
         serialized_properties = self.serialize_properties(edge_properties)
 
-        query = dedent("""MATCH (from_node {id: $from_node}),
+        query = dedent(
+            """MATCH (from_node {id: $from_node}),
                 (to_node {id: $to_node})
                 MERGE (from_node)-[r]->(to_node)
                 ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
                 ON MATCH SET r += $properties, r.updated_at = timestamp()
                 RETURN r
-                """)
+                """
+        )
 
         params = {
             "from_node": str(from_node),

View file

@@ -88,23 +88,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
             }
         )
 
-        return dedent(f"""
+        return dedent(
+            f"""
             MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
             ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
             ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
-        """).strip()
+        """
+        ).strip()
 
     async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
         properties = await self.stringify_properties(edge[3])
         properties = f"{{{properties}}}"
 
-        return dedent(f"""
+        return dedent(
+            f"""
             MERGE (source {{id:'{edge[0]}'}})
             MERGE (target {{id: '{edge[1]}'}})
             MERGE (source)-[edge:{edge[2]} {properties}]->(target)
             ON MATCH SET edge.updated_at = timestamp()
             ON CREATE SET edge.updated_at = timestamp()
-        """).strip()
+        """
+        ).strip()
 
     async def create_collection(self, collection_name: str):
         pass
@@ -195,12 +199,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         self.query(query)
 
     async def has_edges(self, edges):
-        query = dedent("""
+        query = dedent(
+            """
             UNWIND $edges AS edge
             MATCH (a)-[r]->(b)
             WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
             RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
-        """).strip()
+        """
+        ).strip()
 
         params = {
             "edges": [
@@ -279,14 +285,16 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         [label, attribute_name] = collection_name.split(".")
 
-        query = dedent(f"""
+        query = dedent(
+            f"""
             CALL db.idx.vector.queryNodes(
                 '{label}',
                 '{attribute_name}',
                 {limit},
                 vecf32({query_vector})
             ) YIELD node, score
-        """).strip()
+        """
+        ).strip()
 
         result = self.query(query)

View file

@@ -93,10 +93,12 @@ class SQLAlchemyAdapter:
         if self.engine.dialect.name == "postgresql":
             async with self.engine.begin() as connection:
                 result = await connection.execute(
-                    text("""
+                    text(
+                        """
                         SELECT schema_name FROM information_schema.schemata
                         WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
-                    """)
+                    """
+                    )
                 )
                 return [schema[0] for schema in result.fetchall()]
         return []
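For context, the query itself is unchanged: it lists every schema in the connected PostgreSQL database except the built-in system schemas. A minimal standalone sketch of the same lookup with an async SQLAlchemy engine is shown below; the connection URL is an illustrative assumption.

# Standalone sketch of the schema listing with an async SQLAlchemy engine (assumed URL).
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine


async def list_schemas(database_url: str) -> list[str]:
    engine = create_async_engine(database_url)
    async with engine.begin() as connection:
        result = await connection.execute(
            text(
                """
                SELECT schema_name FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
                """
            )
        )
        # Each row has a single column: the schema name.
        return [row[0] for row in result.fetchall()]


if __name__ == "__main__":
    print(asyncio.run(list_schemas("postgresql+asyncpg://user:password@localhost/cognee_db")))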

View file

@@ -1,24 +1,34 @@
 from datetime import datetime, timezone
-from typing import Optional
+from typing import Optional, Any, Dict
 from uuid import UUID, uuid4
 
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
+import pickle
 
 
+# Define metadata type
 class MetaData(TypedDict):
     index_fields: list[str]
 
 
+# Updated DataPoint model with versioning and new fields
 class DataPoint(BaseModel):
     __tablename__ = "data_point"
     id: UUID = Field(default_factory=uuid4)
-    updated_at: Optional[datetime] = datetime.now(timezone.utc)
+    created_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+    )
+    updated_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+    )
+    version: int = 1  # Default version
     topological_rank: Optional[int] = 0
     _metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
 
-    # class Config:
-    #     underscore_attrs_are_private = True
+    # Override the Pydantic configuration
+    class Config:
+        underscore_attrs_are_private = True
 
     @classmethod
     def get_embeddable_data(self, data_point):
@@ -31,11 +41,11 @@ class DataPoint(BaseModel):
             if isinstance(attribute, str):
                 return attribute.strip()
-            else:
-                return attribute
+
+            return attribute
 
     @classmethod
     def get_embeddable_properties(self, data_point):
+        """Retrieve all embeddable properties."""
         if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
             return [
                 getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
@@ -45,4 +55,40 @@ class DataPoint(BaseModel):
     @classmethod
     def get_embeddable_property_names(self, data_point):
+        """Retrieve names of embeddable properties."""
         return data_point._metadata["index_fields"] or []
+
+    def update_version(self):
+        """Update the version and updated_at timestamp."""
+        self.version += 1
+        self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
+
+    # JSON Serialization
+    def to_json(self) -> str:
+        """Serialize the instance to a JSON string."""
+        return self.json()
+
+    @classmethod
+    def from_json(self, json_str: str):
+        """Deserialize the instance from a JSON string."""
+        return self.model_validate_json(json_str)
+
+    # Pickle Serialization
+    def to_pickle(self) -> bytes:
+        """Serialize the instance to pickle-compatible bytes."""
+        return pickle.dumps(self.dict())
+
+    @classmethod
+    def from_pickle(self, pickled_data: bytes):
+        """Deserialize the instance from pickled bytes."""
+        data = pickle.loads(pickled_data)
+        return self(**data)
+
+    def to_dict(self, **kwargs) -> Dict[str, Any]:
+        """Serialize model to a dictionary."""
+        return self.model_dump(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
+        """Deserialize model from a dictionary."""
+        return cls.model_validate(data)
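Taken together, these additions give every DataPoint millisecond-precision created_at and updated_at timestamps, an integer version that update_version() bumps, and JSON/pickle/dict round-trip helpers. A minimal usage sketch follows; the import path is an assumption based on cognee's package layout, and the sketch only exercises the methods added in this diff.

# Usage sketch for the versioned DataPoint (assumed import path).
from cognee.infrastructure.engine import DataPoint

point = DataPoint()
print(point.version, point.created_at, point.updated_at)  # 1, <epoch ms>, <epoch ms>

# Bump the version; updated_at is refreshed to the current epoch milliseconds.
point.update_version()
assert point.version == 2

# Dictionary round-trip via the new to_dict/from_dict helpers.
restored = DataPoint.from_dict(point.to_dict())
assert restored.id == point.id and restored.version == point.version

# JSON round-trip via to_json/from_json.
assert DataPoint.from_json(point.to_json()).version == point.version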

View file

@@ -19,9 +19,11 @@ async def index_and_transform_graphiti_nodes_and_edges():
         raise RuntimeError("Initialization error") from e
 
     await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
-    await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
+    await graph_engine.query(
+        """MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
         r.target_node_id = target.id,
-        r.relationship_name = type(r) RETURN r""")
+        r.relationship_name = type(r) RETURN r"""
+    )
     await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")
 
     nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()

View file

@@ -36,12 +36,12 @@ def test_AudioDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@@ -25,12 +25,12 @@ def test_ImageDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@@ -27,12 +27,12 @@ def test_PdfDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@@ -71,32 +71,32 @@ def test_UnstructuredDocument():
     for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"
 
     # Test DOCX
     for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
-        assert "sentence_end" == paragraph_data.cut_type, (
-            f" sentence_end != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_end" == paragraph_data.cut_type
+        ), f" sentence_end != {paragraph_data.cut_type = }"
 
     # TEST CSV
     for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
-        assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
-            f"Read text doesn't match expected text: {paragraph_data.text}"
-        )
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
+        ), f"Read text doesn't match expected text: {paragraph_data.text}"
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"
 
     # Test XLSX
     for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"

View file

@@ -30,9 +30,9 @@ async def test_deduplication():
     result = await relational_engine.get_all_data_from_table("data")
     assert len(result) == 1, "More than one data entity was found."
-    assert result[0]["name"] == "Natural_language_processing_copy", (
-        "Result name does not match expected value."
-    )
+    assert (
+        result[0]["name"] == "Natural_language_processing_copy"
+    ), "Result name does not match expected value."
 
     result = await relational_engine.get_all_data_from_table("datasets")
     assert len(result) == 2, "Unexpected number of datasets found."
@@ -61,9 +61,9 @@ async def test_deduplication():
     result = await relational_engine.get_all_data_from_table("data")
     assert len(result) == 1, "More than one data entity was found."
-    assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
-        "Content hash is not a part of file name."
-    )
+    assert (
+        hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
+    ), "Content hash is not a part of file name."
 
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)

View file

@@ -85,9 +85,9 @@ async def main():
     from cognee.infrastructure.databases.relational import get_relational_engine
 
-    assert not os.path.exists(get_relational_engine().db_path), (
-        "SQLite relational database is not empty"
-    )
+    assert not os.path.exists(
+        get_relational_engine().db_path
+    ), "SQLite relational database is not empty"
 
     from cognee.infrastructure.databases.graph import get_graph_config

View file

@@ -82,9 +82,9 @@ async def main():
     from cognee.infrastructure.databases.relational import get_relational_engine
 
-    assert not os.path.exists(get_relational_engine().db_path), (
-        "SQLite relational database is not empty"
-    )
+    assert not os.path.exists(
+        get_relational_engine().db_path
+    ), "SQLite relational database is not empty"
 
     from cognee.infrastructure.databases.graph import get_graph_config

View file

@@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
         data_hash = hashlib.md5(encoded_text).hexdigest()
         # Get data entry from database based on hash contents
         data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
-        assert os.path.isfile(data.raw_data_location), (
-            f"Data location doesn't exist: {data.raw_data_location}"
-        )
+        assert os.path.isfile(
+            data.raw_data_location
+        ), f"Data location doesn't exist: {data.raw_data_location}"
 
         # Test deletion of data along with local files created by cognee
         await engine.delete_data_entity(data.id)
-        assert not os.path.exists(data.raw_data_location), (
-            f"Data location still exists after deletion: {data.raw_data_location}"
-        )
+        assert not os.path.exists(
+            data.raw_data_location
+        ), f"Data location still exists after deletion: {data.raw_data_location}"
 
     async with engine.get_async_session() as session:
         # Get data entry from database based on file path
         data = (
             await session.scalars(select(Data).where(Data.raw_data_location == file_location))
         ).one()
-        assert os.path.isfile(data.raw_data_location), (
-            f"Data location doesn't exist: {data.raw_data_location}"
-        )
+        assert os.path.isfile(
+            data.raw_data_location
+        ), f"Data location doesn't exist: {data.raw_data_location}"
 
         # Test local files not created by cognee won't get deleted
         await engine.delete_data_entity(data.id)
-        assert os.path.exists(data.raw_data_location), (
-            f"Data location doesn't exists: {data.raw_data_location}"
-        )
+        assert os.path.exists(
+            data.raw_data_location
+        ), f"Data location doesn't exists: {data.raw_data_location}"
 
 
 async def test_getting_of_documents(dataset_name_1):
@@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
     user = await get_default_user()
     document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
-    assert len(document_ids) == 1, (
-        f"Number of expected documents doesn't match {len(document_ids)} != 1"
-    )
+    assert (
+        len(document_ids) == 1
+    ), f"Number of expected documents doesn't match {len(document_ids)} != 1"
 
     # Test getting of documents for search when no dataset is provided
     user = await get_default_user()
     document_ids = await get_document_ids_for_user(user.id)
-    assert len(document_ids) == 2, (
-        f"Number of expected documents doesn't match {len(document_ids)} != 2"
-    )
+    assert (
+        len(document_ids) == 2
+    ), f"Number of expected documents doesn't match {len(document_ids)} != 2"
 
 
 async def main():

View file

@@ -17,9 +17,9 @@ batch_paragraphs_vals = [True, False]
 def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
     chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
     reconstructed_text = "".join([chunk["text"] for chunk in chunks])
-    assert reconstructed_text == input_text, (
-        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
-    )
+    assert (
+        reconstructed_text == input_text
+    ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
 
 
 @pytest.mark.parametrize(
@@ -36,9 +36,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
     chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
 
     larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
-    assert np.all(chunk_lengths <= paragraph_length), (
-        f"{paragraph_length = }: {larger_chunks} are too large"
-    )
+    assert np.all(
+        chunk_lengths <= paragraph_length
+    ), f"{paragraph_length = }: {larger_chunks} are too large"
 
 
 @pytest.mark.parametrize(
@@ -50,6 +50,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_
         data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
     )
     chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
-    assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
-        f"{chunk_indices = } are not monotonically increasing"
-    )
+    assert np.all(
+        chunk_indices == np.arange(len(chunk_indices))
+    ), f"{chunk_indices = } are not monotonically increasing"

View file

@@ -58,9 +58,9 @@ def run_chunking_test(test_text, expected_chunks):
     for expected_chunks_item, chunk in zip(expected_chunks, chunks):
         for key in ["text", "word_count", "cut_type"]:
-            assert chunk[key] == expected_chunks_item[key], (
-                f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
-            )
+            assert (
+                chunk[key] == expected_chunks_item[key]
+            ), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
 
 
 def test_chunking_whole_text():

View file

@@ -16,9 +16,9 @@ maximum_length_vals = [None, 8, 64]
 def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
     chunks = chunk_by_sentence(input_text, maximum_length)
     reconstructed_text = "".join([chunk[1] for chunk in chunks])
-    assert reconstructed_text == input_text, (
-        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
-    )
+    assert (
+        reconstructed_text == input_text
+    ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
 
 
 @pytest.mark.parametrize(
@@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length):
     chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])
 
     larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
-    assert np.all(chunk_lengths <= maximum_length), (
-        f"{maximum_length = }: {larger_chunks} are too large"
-    )
+    assert np.all(
+        chunk_lengths <= maximum_length
+    ), f"{maximum_length = }: {larger_chunks} are too large"

View file

@@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
 def test_chunk_by_word_isomorphism(input_text):
     chunks = chunk_by_word(input_text)
     reconstructed_text = "".join([chunk[0] for chunk in chunks])
-    assert reconstructed_text == input_text, (
-        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
-    )
+    assert (
+        reconstructed_text == input_text
+    ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
 
 
 @pytest.mark.parametrize(