reformat applied

This commit is contained in:
vasilije 2025-07-01 14:26:56 +02:00
parent 1ebeeac61d
commit 07f2afa69d
39 changed files with 577 additions and 531 deletions

View file

@ -5,43 +5,49 @@ Revises: 1d0bb7fede17
Create Date: 2025-01-27 12:00:00.000000 Create Date: 2025-01-27 12:00:00.000000
""" """
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from uuid import uuid4 from uuid import uuid4
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'incremental_file_signatures' revision = "incremental_file_signatures"
down_revision = '1d0bb7fede17' down_revision = "1d0bb7fede17"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('file_signatures', op.create_table(
sa.Column('id', sa.UUID(), nullable=False, default=uuid4), "file_signatures",
sa.Column('data_id', sa.UUID(), nullable=True), sa.Column("id", sa.UUID(), nullable=False, default=uuid4),
sa.Column('file_path', sa.String(), nullable=True), sa.Column("data_id", sa.UUID(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True), sa.Column("file_path", sa.String(), nullable=True),
sa.Column('content_hash', sa.String(), nullable=True), sa.Column("file_size", sa.Integer(), nullable=True),
sa.Column('total_blocks', sa.Integer(), nullable=True), sa.Column("content_hash", sa.String(), nullable=True),
sa.Column('block_size', sa.Integer(), nullable=True), sa.Column("total_blocks", sa.Integer(), nullable=True),
sa.Column('strong_len', sa.Integer(), nullable=True), sa.Column("block_size", sa.Integer(), nullable=True),
sa.Column('signature_data', sa.LargeBinary(), nullable=True), sa.Column("strong_len", sa.Integer(), nullable=True),
sa.Column('blocks_info', sa.JSON(), nullable=True), sa.Column("signature_data", sa.LargeBinary(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.Column("blocks_info", sa.JSON(), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_file_signatures_data_id"), "file_signatures", ["data_id"], unique=False
)
op.create_index(
op.f("ix_file_signatures_content_hash"), "file_signatures", ["content_hash"], unique=False
) )
op.create_index(op.f('ix_file_signatures_data_id'), 'file_signatures', ['data_id'], unique=False)
op.create_index(op.f('ix_file_signatures_content_hash'), 'file_signatures', ['content_hash'], unique=False)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_file_signatures_content_hash'), table_name='file_signatures') op.drop_index(op.f("ix_file_signatures_content_hash"), table_name="file_signatures")
op.drop_index(op.f('ix_file_signatures_data_id'), table_name='file_signatures') op.drop_index(op.f("ix_file_signatures_data_id"), table_name="file_signatures")
op.drop_table('file_signatures') op.drop_table("file_signatures")
# ### end Alembic commands ### # ### end Alembic commands ###

View file

@ -18,6 +18,7 @@ litellm.set_verbose = False
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL) logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
logging.getLogger("litellm").setLevel(logging.CRITICAL) logging.getLogger("litellm").setLevel(logging.CRITICAL)
class GenericAPIAdapter(LLMInterface): class GenericAPIAdapter(LLMInterface):
""" """
Adapter for Generic API LLM provider API. Adapter for Generic API LLM provider API.

View file

@ -28,7 +28,9 @@ class FileSignature(Base):
signature_data = Column(LargeBinary) signature_data = Column(LargeBinary)
# Block details (JSON array of block info) # Block details (JSON array of block info)
blocks_info = Column(JSON) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset} blocks_info = Column(
JSON
) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
# Timestamps # Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

View file

@ -18,6 +18,7 @@ import tempfile
@dataclass @dataclass
class BlockInfo: class BlockInfo:
"""Information about a file block""" """Information about a file block"""
block_index: int block_index: int
weak_checksum: int weak_checksum: int
strong_hash: str strong_hash: str
@ -28,6 +29,7 @@ class BlockInfo:
@dataclass @dataclass
class FileSignature: class FileSignature:
"""File signature containing block information""" """File signature containing block information"""
file_path: str file_path: str
file_size: int file_size: int
total_blocks: int total_blocks: int
@ -40,6 +42,7 @@ class FileSignature:
@dataclass @dataclass
class FileDelta: class FileDelta:
"""Delta information for changed blocks""" """Delta information for changed blocks"""
changed_blocks: List[int] # Block indices that changed changed_blocks: List[int] # Block indices that changed
delta_data: bytes delta_data: bytes
old_signature: FileSignature old_signature: FileSignature
@ -80,9 +83,7 @@ class BlockHashService:
# Calculate optimal signature parameters # Calculate optimal signature parameters
magic, block_len, strong_len = get_signature_args( magic, block_len, strong_len = get_signature_args(
file_size, file_size, block_len=self.block_size, strong_len=self.strong_len
block_len=self.block_size,
strong_len=self.strong_len
) )
# Generate signature using librsync # Generate signature using librsync
@ -102,10 +103,12 @@ class BlockHashService:
block_size=block_len, block_size=block_len,
strong_len=strong_len, strong_len=strong_len,
blocks=blocks, blocks=blocks,
signature_data=signature_data signature_data=signature_data,
) )
def _parse_signature(self, signature_data: bytes, file_data: bytes, block_size: int) -> List[BlockInfo]: def _parse_signature(
self, signature_data: bytes, file_data: bytes, block_size: int
) -> List[BlockInfo]:
""" """
Parse signature data to extract block information Parse signature data to extract block information
@ -131,13 +134,15 @@ class BlockHashService:
# Calculate strong hash (MD5) # Calculate strong hash (MD5)
strong_hash = hashlib.md5(block_data).hexdigest() strong_hash = hashlib.md5(block_data).hexdigest()
blocks.append(BlockInfo( blocks.append(
block_index=i, BlockInfo(
weak_checksum=weak_checksum, block_index=i,
strong_hash=strong_hash, weak_checksum=weak_checksum,
block_size=len(block_data), strong_hash=strong_hash,
file_offset=start_offset block_size=len(block_data),
)) file_offset=start_offset,
)
)
return blocks return blocks
@ -185,15 +190,18 @@ class BlockHashService:
if old_block is None or new_block is None: if old_block is None or new_block is None:
# Block was added or removed # Block was added or removed
changed_blocks.append(block_idx) changed_blocks.append(block_idx)
elif (old_block.weak_checksum != new_block.weak_checksum or elif (
old_block.strong_hash != new_block.strong_hash): old_block.weak_checksum != new_block.weak_checksum
or old_block.strong_hash != new_block.strong_hash
):
# Block content changed # Block content changed
changed_blocks.append(block_idx) changed_blocks.append(block_idx)
return sorted(changed_blocks) return sorted(changed_blocks)
def generate_delta(self, old_file: BinaryIO, new_file: BinaryIO, def generate_delta(
old_signature: FileSignature = None) -> FileDelta: self, old_file: BinaryIO, new_file: BinaryIO, old_signature: FileSignature = None
) -> FileDelta:
""" """
Generate a delta between two file versions Generate a delta between two file versions
@ -226,7 +234,7 @@ class BlockHashService:
changed_blocks=changed_blocks, changed_blocks=changed_blocks,
delta_data=delta_data, delta_data=delta_data,
old_signature=old_signature, old_signature=old_signature,
new_signature=new_signature new_signature=new_signature,
) )
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO: def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
@ -249,7 +257,9 @@ class BlockHashService:
return result_io return result_io
def calculate_block_changes(self, old_sig: FileSignature, new_sig: FileSignature) -> Dict[str, Any]: def calculate_block_changes(
self, old_sig: FileSignature, new_sig: FileSignature
) -> Dict[str, Any]:
""" """
Calculate detailed statistics about block changes Calculate detailed statistics about block changes

View file

@ -37,7 +37,9 @@ class IncrementalLoader:
""" """
self.block_service = BlockHashService(block_size, strong_len) self.block_service = BlockHashService(block_size, strong_len)
async def should_process_file(self, file_obj: BinaryIO, data_id: str) -> Tuple[bool, Optional[Dict]]: async def should_process_file(
self, file_obj: BinaryIO, data_id: str
) -> Tuple[bool, Optional[Dict]]:
""" """
Determine if a file should be processed based on incremental changes Determine if a file should be processed based on incremental changes
@ -74,18 +76,23 @@ class IncrementalLoader:
service_old_sig = self._db_signature_to_service(existing_signature) service_old_sig = self._db_signature_to_service(existing_signature)
# Compare signatures to find changed blocks # Compare signatures to find changed blocks
changed_blocks = self.block_service.compare_signatures(service_old_sig, current_signature) changed_blocks = self.block_service.compare_signatures(
service_old_sig, current_signature
)
if not changed_blocks: if not changed_blocks:
# Signatures match, no processing needed # Signatures match, no processing needed
return False, {"type": "no_changes", "full_processing": False} return False, {"type": "no_changes", "full_processing": False}
# Calculate change statistics # Calculate change statistics
change_stats = self.block_service.calculate_block_changes(service_old_sig, current_signature) change_stats = self.block_service.calculate_block_changes(
service_old_sig, current_signature
)
change_info = { change_info = {
"type": "incremental_changes", "type": "incremental_changes",
"full_processing": len(changed_blocks) > (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess "full_processing": len(changed_blocks)
> (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
"changed_blocks": changed_blocks, "changed_blocks": changed_blocks,
"stats": change_stats, "stats": change_stats,
"new_signature": current_signature, "new_signature": current_signature,
@ -94,8 +101,9 @@ class IncrementalLoader:
return True, change_info return True, change_info
async def process_incremental_changes(self, file_obj: BinaryIO, data_id: str, async def process_incremental_changes(
change_info: Dict) -> List[Dict]: self, file_obj: BinaryIO, data_id: str, change_info: Dict
) -> List[Dict]:
""" """
Process only the changed blocks of a file Process only the changed blocks of a file
@ -135,13 +143,15 @@ class IncrementalLoader:
end_offset = start_offset + block_info.block_size end_offset = start_offset + block_info.block_size
block_data = file_data[start_offset:end_offset] block_data = file_data[start_offset:end_offset]
changed_block_data.append({ changed_block_data.append(
"block_index": block_idx, {
"block_data": block_data, "block_index": block_idx,
"block_info": block_info, "block_data": block_data,
"file_offset": start_offset, "block_info": block_info,
"block_size": len(block_data), "file_offset": start_offset,
}) "block_size": len(block_data),
}
)
return changed_block_data return changed_block_data
@ -210,7 +220,9 @@ class IncrementalLoader:
await session.commit() await session.commit()
async def _get_existing_signature(self, session: AsyncSession, data_id: str) -> Optional[FileSignature]: async def _get_existing_signature(
self, session: AsyncSession, data_id: str
) -> Optional[FileSignature]:
""" """
Get existing file signature from database Get existing file signature from database

View file

@ -114,7 +114,9 @@ async def ingest_data(
incremental_loader = IncrementalLoader() incremental_loader = IncrementalLoader()
# Check if file needs incremental processing # Check if file needs incremental processing
should_process, change_info = await incremental_loader.should_process_file(file, data_id) should_process, change_info = await incremental_loader.should_process_file(
file, data_id
)
# Save updated file signature regardless of whether processing is needed # Save updated file signature regardless of whether processing is needed
await incremental_loader.save_file_signature(file, data_id) await incremental_loader.save_file_signature(file, data_id)
@ -155,7 +157,7 @@ async def ingest_data(
ext_metadata["incremental_processing"] = { ext_metadata["incremental_processing"] = {
"should_process": should_process, "should_process": should_process,
"change_info": change_info, "change_info": change_info,
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())) "processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())),
} }
if data_point is not None: if data_point is not None:

View file

@ -51,12 +51,12 @@ def test_AudioDocument(mock_engine):
GROUND_TRUTH, GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64), document.read(chunker_cls=TextChunker, max_chunk_size=64),
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -34,12 +34,12 @@ def test_ImageDocument(mock_engine):
GROUND_TRUTH, GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64), document.read(chunker_cls=TextChunker, max_chunk_size=64),
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -36,12 +36,12 @@ def test_PdfDocument(mock_engine):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024) GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -49,12 +49,12 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
GROUND_TRUTH[input_file], GROUND_TRUTH[input_file],
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size), document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -79,32 +79,32 @@ def test_UnstructuredDocument(mock_engine):
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }" assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, ( assert (
f" sentence_cut != {paragraph_data.cut_type = }" "sentence_cut" == paragraph_data.cut_type
) ), f" sentence_cut != {paragraph_data.cut_type = }"
# Test DOCX # Test DOCX
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }" assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert "sentence_end" == paragraph_data.cut_type, ( assert (
f" sentence_end != {paragraph_data.cut_type = }" "sentence_end" == paragraph_data.cut_type
) ), f" sentence_end != {paragraph_data.cut_type = }"
# TEST CSV # TEST CSV
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }" assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, ( assert (
f"Read text doesn't match expected text: {paragraph_data.text}" "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
) ), f"Read text doesn't match expected text: {paragraph_data.text}"
assert "sentence_cut" == paragraph_data.cut_type, ( assert (
f" sentence_cut != {paragraph_data.cut_type = }" "sentence_cut" == paragraph_data.cut_type
) ), f" sentence_cut != {paragraph_data.cut_type = }"
# Test XLSX # Test XLSX
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }" assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, ( assert (
f" sentence_cut != {paragraph_data.cut_type = }" "sentence_cut" == paragraph_data.cut_type
) ), f" sentence_cut != {paragraph_data.cut_type = }"

View file

@ -12,9 +12,9 @@ async def check_graph_metrics_consistency_across_adapters(include_optional=False
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}") raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
for key, neo4j_value in neo4j_metrics.items(): for key, neo4j_value in neo4j_metrics.items():
assert networkx_metrics[key] == neo4j_value, ( assert (
f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx" networkx_metrics[key] == neo4j_value
) ), f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -71,6 +71,6 @@ async def assert_metrics(provider, include_optional=True):
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}") raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
for key, ground_truth_value in ground_truth_metrics.items(): for key, ground_truth_value in ground_truth_metrics.items():
assert metrics[key] == ground_truth_value, ( assert (
f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}" metrics[key] == ground_truth_value
) ), f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"

View file

@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest() data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents # Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one() data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee # Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), ( assert not os.path.exists(
f"Data location still exists after deletion: {data.raw_data_location}" data.raw_data_location
) ), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session: async with engine.get_async_session() as session:
# Get data entry from database based on file path # Get data entry from database based on file path
data = ( data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location)) await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one() ).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted # Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), ( assert os.path.exists(
f"Data location doesn't exists: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1): async def test_getting_of_documents(dataset_name_1):
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 1" len(document_ids) == 1
) ), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided # Test getting of documents for search when no dataset is provided
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id) document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 2" len(document_ids) == 2
) ), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main(): async def main():

View file

@ -30,9 +30,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data") result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found." assert len(result) == 1, "More than one data entity was found."
assert result[0]["name"] == "Natural_language_processing_copy", ( assert (
"Result name does not match expected value." result[0]["name"] == "Natural_language_processing_copy"
) ), "Result name does not match expected value."
result = await relational_engine.get_all_data_from_table("datasets") result = await relational_engine.get_all_data_from_table("datasets")
assert len(result) == 2, "Unexpected number of datasets found." assert len(result) == 2, "Unexpected number of datasets found."
@ -61,9 +61,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data") result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found." assert len(result) == 1, "More than one data entity was found."
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], ( assert (
"Content hash is not a part of file name." hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
) ), "Content hash is not a part of file name."
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)

View file

@ -92,9 +92,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(get_relational_engine().db_path), ( assert not os.path.exists(
"SQLite relational database is not empty" get_relational_engine().db_path
) ), "SQLite relational database is not empty"
from cognee.infrastructure.databases.graph import get_graph_config from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -103,13 +103,13 @@ async def main():
node_name=["nonexistent"], node_name=["nonexistent"],
).get_context("What is in the context?") ).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", ( assert (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" isinstance(context_nonempty, str) and context_nonempty != ""
) ), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
assert context_empty == "", ( assert (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" context_empty == ""
) ), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -107,13 +107,13 @@ async def main():
node_name=["nonexistent"], node_name=["nonexistent"],
).get_context("What is in the context?") ).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", ( assert (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" isinstance(context_nonempty, str) and context_nonempty != ""
) ), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
assert context_empty == "", ( assert (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" context_empty == ""
) ), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -23,28 +23,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest() data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents # Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one() data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee # Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), ( assert not os.path.exists(
f"Data location still exists after deletion: {data.raw_data_location}" data.raw_data_location
) ), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session: async with engine.get_async_session() as session:
# Get data entry from database based on file path # Get data entry from database based on file path
data = ( data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location)) await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one() ).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted # Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), ( assert os.path.exists(
f"Data location doesn't exists: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1): async def test_getting_of_documents(dataset_name_1):
@ -53,16 +53,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 1" len(document_ids) == 1
) ), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided # Test getting of documents for search when no dataset is provided
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id) document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 2" len(document_ids) == 2
) ), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main(): async def main():

View file

@ -112,9 +112,9 @@ async def relational_db_migration():
else: else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
assert len(distinct_node_names) == 12, ( assert (
f"Expected 12 distinct node references, found {len(distinct_node_names)}" len(distinct_node_names) == 12
) ), f"Expected 12 distinct node references, found {len(distinct_node_names)}"
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}" assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
expected_edges = { expected_edges = {

View file

@ -29,54 +29,54 @@ async def main():
logging.info(edge_type_counts) logging.info(edge_type_counts)
# Assert there is exactly one PdfDocument. # Assert there is exactly one PdfDocument.
assert type_counts.get("PdfDocument", 0) == 1, ( assert (
f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}" type_counts.get("PdfDocument", 0) == 1
) ), f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
# Assert there is exactly one TextDocument. # Assert there is exactly one TextDocument.
assert type_counts.get("TextDocument", 0) == 1, ( assert (
f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}" type_counts.get("TextDocument", 0) == 1
) ), f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
# Assert there are at least two DocumentChunk nodes. # Assert there are at least two DocumentChunk nodes.
assert type_counts.get("DocumentChunk", 0) >= 2, ( assert (
f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}" type_counts.get("DocumentChunk", 0) >= 2
) ), f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
# Assert there is at least two TextSummary. # Assert there is at least two TextSummary.
assert type_counts.get("TextSummary", 0) >= 2, ( assert (
f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}" type_counts.get("TextSummary", 0) >= 2
) ), f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
# Assert there is at least one Entity. # Assert there is at least one Entity.
assert type_counts.get("Entity", 0) > 0, ( assert (
f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}" type_counts.get("Entity", 0) > 0
) ), f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
# Assert there is at least one EntityType. # Assert there is at least one EntityType.
assert type_counts.get("EntityType", 0) > 0, ( assert (
f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}" type_counts.get("EntityType", 0) > 0
) ), f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
# Assert that there are at least two 'is_part_of' edges. # Assert that there are at least two 'is_part_of' edges.
assert edge_type_counts.get("is_part_of", 0) >= 2, ( assert (
f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}" edge_type_counts.get("is_part_of", 0) >= 2
) ), f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
# Assert that there are at least two 'made_from' edges. # Assert that there are at least two 'made_from' edges.
assert edge_type_counts.get("made_from", 0) >= 2, ( assert (
f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}" edge_type_counts.get("made_from", 0) >= 2
) ), f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
# Assert that there is at least one 'is_a' edge. # Assert that there is at least one 'is_a' edge.
assert edge_type_counts.get("is_a", 0) >= 1, ( assert (
f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}" edge_type_counts.get("is_a", 0) >= 1
) ), f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
# Assert that there is at least one 'contains' edge. # Assert that there is at least one 'contains' edge.
assert edge_type_counts.get("contains", 0) >= 1, ( assert (
f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}" edge_type_counts.get("contains", 0) >= 1
) ), f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -66,9 +66,9 @@ async def main():
assert isinstance(context, str), f"{name}: Context should be a string" assert isinstance(context, str), f"{name}: Context should be a string"
assert context.strip(), f"{name}: Context should not be empty" assert context.strip(), f"{name}: Context should not be empty"
lower = context.lower() lower = context.lower()
assert "germany" in lower or "netherlands" in lower, ( assert (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}" "germany" in lower or "netherlands" in lower
) ), f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
triplets_gk = await GraphCompletionRetriever().get_triplets( triplets_gk = await GraphCompletionRetriever().get_triplets(
query="Next to which country is Germany located?" query="Next to which country is Germany located?"
@ -96,18 +96,18 @@ async def main():
distance = edge.attributes.get("vector_distance") distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance") node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance") node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), ( assert isinstance(
f"{name}: vector_distance should be float, got {type(distance)}" distance, float
) ), f"{name}: vector_distance should be float, got {type(distance)}"
assert 0 <= distance <= 1, ( assert (
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen" 0 <= distance <= 1
) ), f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
assert 0 <= node1_distance <= 1, ( assert (
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen" 0 <= node1_distance <= 1
) ), f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
assert 0 <= node2_distance <= 1, ( assert (
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen" 0 <= node2_distance <= 1
) ), f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
completion_gk = await cognee.search( completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_type=SearchType.GRAPH_COMPLETION,
@ -137,9 +137,9 @@ async def main():
text = completion[0] text = completion[0]
assert isinstance(text, str), f"{name}: element should be a string" assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty" assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), ( assert (
f"{name}: expected 'netherlands' in result, got: {text!r}" "netherlands" in text.lower()
) ), f"{name}: expected 'netherlands' in result, got: {text!r}"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -24,12 +24,12 @@ async def test_answer_generation():
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"]) mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
assert len(answers) == len(qa_pairs) assert len(answers) == len(qa_pairs)
assert answers[0]["question"] == qa_pairs[0]["question"], ( assert (
"AnswerGeneratorExecutor is passing the question incorrectly" answers[0]["question"] == qa_pairs[0]["question"]
) ), "AnswerGeneratorExecutor is passing the question incorrectly"
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], ( assert (
"AnswerGeneratorExecutor is passing the golden answer incorrectly" answers[0]["golden_answer"] == qa_pairs[0]["answer"]
) ), "AnswerGeneratorExecutor is passing the golden answer incorrectly"
assert answers[0]["answer"] == "Mocked answer", ( assert (
"AnswerGeneratorExecutor is passing the generated answer incorrectly" answers[0]["answer"] == "Mocked answer"
) ), "AnswerGeneratorExecutor is passing the generated answer incorrectly"

View file

@ -44,9 +44,9 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
corpus_list, qa_pairs = result corpus_list, qa_pairs = result
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list." assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
assert isinstance(qa_pairs, list), ( assert isinstance(
f"{AdapterClass.__name__} question_answer_pairs is not a list." qa_pairs, list
) ), f"{AdapterClass.__name__} question_answer_pairs is not a list."
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES) @pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
@ -71,9 +71,9 @@ def test_adapter_returns_some_content(AdapterClass):
# We don't know how large the dataset is, but we expect at least 1 item # We don't know how large the dataset is, but we expect at least 1 item
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list." assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs." assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
assert len(qa_pairs) <= limit, ( assert (
f"{AdapterClass.__name__} returned more QA items than requested limit={limit}." len(qa_pairs) <= limit
) ), f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
for item in qa_pairs: for item in qa_pairs:
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair." assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."

View file

@ -12,9 +12,9 @@ def test_corpus_builder_load_corpus(benchmark):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default") corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit) raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}" assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, ( assert (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" len(questions) <= 2
) ), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -24,6 +24,6 @@ async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
limit = 2 limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default") corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit) questions = await corpus_builder.build_corpus(limit=limit)
assert len(questions) <= 2, ( assert (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" len(questions) <= 2
) ), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"

View file

@ -52,14 +52,14 @@ def test_metrics(metrics, actual, expected, expected_exact_score, expected_f1_ra
test_case = MockTestCase(actual, expected) test_case = MockTestCase(actual, expected)
exact_match_score = metrics["exact_match"].measure(test_case) exact_match_score = metrics["exact_match"].measure(test_case)
assert exact_match_score == expected_exact_score, ( assert (
f"Exact match failed for '{actual}' vs '{expected}'" exact_match_score == expected_exact_score
) ), f"Exact match failed for '{actual}' vs '{expected}'"
f1_score = metrics["f1"].measure(test_case) f1_score = metrics["f1"].measure(test_case)
assert expected_f1_range[0] <= f1_score <= expected_f1_range[1], ( assert (
f"F1 score failed for '{actual}' vs '{expected}'" expected_f1_range[0] <= f1_score <= expected_f1_range[1]
) ), f"F1 score failed for '{actual}' vs '{expected}'"
class TestBootstrapCI(unittest.TestCase): class TestBootstrapCI(unittest.TestCase):

View file

@ -157,15 +157,15 @@ def test_rate_limit_60_per_minute():
if len(failures) > 0: if len(failures) > 0:
first_failure_idx = int(failures[0].split()[1]) first_failure_idx = int(failures[0].split()[1])
print(f"First failure occurred at request index: {first_failure_idx}") print(f"First failure occurred at request index: {first_failure_idx}")
assert 58 <= first_failure_idx <= 62, ( assert (
f"Expected first failure around request #60, got #{first_failure_idx}" 58 <= first_failure_idx <= 62
) ), f"Expected first failure around request #60, got #{first_failure_idx}"
# Calculate requests per minute # Calculate requests per minute
rate_per_minute = len(successes) rate_per_minute = len(successes)
print(f"Rate: {rate_per_minute} requests per minute") print(f"Rate: {rate_per_minute} requests per minute")
# Verify the rate is close to our target of 60 requests per minute # Verify the rate is close to our target of 60 requests per minute
assert 58 <= rate_per_minute <= 62, ( assert (
f"Expected rate of ~60 requests per minute, got {rate_per_minute}" 58 <= rate_per_minute <= 62
) ), f"Expected rate of ~60 requests per minute, got {rate_per_minute}"

View file

@ -110,9 +110,9 @@ def test_sync_retry():
print(f"Number of attempts: {test_function_sync.counter}") print(f"Number of attempts: {test_function_sync.counter}")
# The function should succeed on the 3rd attempt (after 2 failures) # The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_sync.counter == 3, ( assert (
f"Expected 3 attempts, got {test_function_sync.counter}" test_function_sync.counter == 3
) ), f"Expected 3 attempts, got {test_function_sync.counter}"
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}" assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Synchronous retry mechanism is working correctly") print("✅ PASS: Synchronous retry mechanism is working correctly")
@ -143,9 +143,9 @@ async def test_async_retry():
print(f"Number of attempts: {test_function_async.counter}") print(f"Number of attempts: {test_function_async.counter}")
# The function should succeed on the 3rd attempt (after 2 failures) # The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_async.counter == 3, ( assert (
f"Expected 3 attempts, got {test_function_async.counter}" test_function_async.counter == 3
) ), f"Expected 3 attempts, got {test_function_async.counter}"
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}" assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Asynchronous retry mechanism is working correctly") print("✅ PASS: Asynchronous retry mechanism is working correctly")

View file

@ -57,9 +57,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Canva?") answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self): async def test_graph_completion_extension_context_complex(self):
@ -136,9 +136,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self): async def test_get_graph_completion_extension_context_on_empty_graph(self):
@ -167,9 +167,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -55,9 +55,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Canva?") answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self): async def test_graph_completion_cot_context_complex(self):
@ -134,9 +134,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self): async def test_get_graph_completion_cot_context_on_empty_graph(self):
@ -165,9 +165,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -24,9 +24,9 @@ max_chunk_size_vals = [512, 1024, 4096]
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs): def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs) chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks]) reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert reconstructed_text == input_text, ( assert (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" reconstructed_text == input_text
) ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -54,9 +54,9 @@ def test_paragraph_chunk_length(input_text, max_chunk_size, batch_paragraphs):
) )
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
assert np.all(chunk_lengths <= max_chunk_size), ( assert np.all(
f"{max_chunk_size = }: {larger_chunks} are too large" chunk_lengths <= max_chunk_size
) ), f"{max_chunk_size = }: {larger_chunks} are too large"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -76,6 +76,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, max_chunk_size, batch_pa
batch_paragraphs=batch_paragraphs, batch_paragraphs=batch_paragraphs,
) )
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( assert np.all(
f"{chunk_indices = } are not monotonically increasing" chunk_indices == np.arange(len(chunk_indices))
) ), f"{chunk_indices = } are not monotonically increasing"

View file

@ -71,9 +71,9 @@ def run_chunking_test(test_text, expected_chunks, mock_engine):
for expected_chunks_item, chunk in zip(expected_chunks, chunks): for expected_chunks_item, chunk in zip(expected_chunks, chunks):
for key in ["text", "chunk_size", "cut_type"]: for key in ["text", "chunk_size", "cut_type"]:
assert chunk[key] == expected_chunks_item[key], ( assert (
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }" chunk[key] == expected_chunks_item[key]
) ), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
def test_chunking_whole_text(): def test_chunking_whole_text():

View file

@ -17,9 +17,9 @@ maximum_length_vals = [None, 16, 64]
def test_chunk_by_sentence_isomorphism(input_text, maximum_length): def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
chunks = chunk_by_sentence(input_text, maximum_length) chunks = chunk_by_sentence(input_text, maximum_length)
reconstructed_text = "".join([chunk[1] for chunk in chunks]) reconstructed_text = "".join([chunk[1] for chunk in chunks])
assert reconstructed_text == input_text, ( assert (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" reconstructed_text == input_text
) ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -40,9 +40,9 @@ def test_paragraph_chunk_length(input_text, maximum_length):
) )
larger_chunks = chunk_lengths[chunk_lengths > maximum_length] larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
assert np.all(chunk_lengths <= maximum_length), ( assert np.all(
f"{maximum_length = }: {larger_chunks} are too large" chunk_lengths <= maximum_length
) ), f"{maximum_length = }: {larger_chunks} are too large"
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS, INPUT_TE
def test_chunk_by_word_isomorphism(input_text): def test_chunk_by_word_isomorphism(input_text):
chunks = chunk_by_word(input_text) chunks = chunk_by_word(input_text)
reconstructed_text = "".join([chunk[0] for chunk in chunks]) reconstructed_text = "".join([chunk[0] for chunk in chunks])
assert reconstructed_text == input_text, ( assert (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" reconstructed_text == input_text
) ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -66,8 +66,12 @@ This is the end of the modified content.
initial_signature = block_service.generate_signature(initial_file, "test_file.txt") initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
modified_signature = block_service.generate_signature(modified_file, "test_file.txt") modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks") print(
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks") f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
)
print(
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
)
# Compare signatures to find changes # Compare signatures to find changes
print("\n2. Comparing signatures to detect changes...") print("\n2. Comparing signatures to detect changes...")
@ -77,7 +81,9 @@ This is the end of the modified content.
print(f" Changed blocks: {changed_blocks}") print(f" Changed blocks: {changed_blocks}")
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}") print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}") print(
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
)
# Generate delta # Generate delta
print("\n3. Generating delta for changed content...") print("\n3. Generating delta for changed content...")
@ -145,7 +151,7 @@ async def demonstrate_with_cognee():
print("=" * 50) print("=" * 50)
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("Initial content for Cognee processing.") f.write("Initial content for Cognee processing.")
temp_file_path = f.name temp_file_path = f.name
@ -158,7 +164,7 @@ async def demonstrate_with_cognee():
print(" ✅ File added successfully") print(" ✅ File added successfully")
# Modify the file # Modify the file
with open(temp_file_path, 'w') as f: with open(temp_file_path, "w") as f:
f.write("Modified content for Cognee processing with additional text.") f.write("Modified content for Cognee processing with additional text.")
print("2. Adding modified version of the same file...") print("2. Adding modified version of the same file...")

View file

@ -4,7 +4,8 @@ Simple test for incremental loading functionality
""" """
import sys import sys
sys.path.insert(0, '.')
sys.path.insert(0, ".")
from io import BytesIO from io import BytesIO
from cognee.modules.ingestion.incremental import BlockHashService from cognee.modules.ingestion.incremental import BlockHashService
@ -47,8 +48,12 @@ Line 6: NEW - This is additional content"""
initial_signature = block_service.generate_signature(initial_file, "test_file.txt") initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
modified_signature = block_service.generate_signature(modified_file, "test_file.txt") modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks") print(
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks") f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
)
print(
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
)
# Compare signatures to find changes # Compare signatures to find changes
print("\n2. Comparing signatures to detect changes...") print("\n2. Comparing signatures to detect changes...")
@ -58,7 +63,9 @@ Line 6: NEW - This is additional content"""
print(f" Changed blocks: {changed_blocks}") print(f" Changed blocks: {changed_blocks}")
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}") print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}") print(
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
)
# Generate delta # Generate delta
print("\n3. Generating delta for changed content...") print("\n3. Generating delta for changed content...")

View file

@ -144,13 +144,13 @@ class TestIncrementalLoader:
block_size=10, block_size=10,
strong_len=8, strong_len=8,
blocks=blocks, blocks=blocks,
signature_data=b"signature" signature_data=b"signature",
) )
change_info = { change_info = {
"type": "incremental_changes", "type": "incremental_changes",
"changed_blocks": [1], # Only middle block changed "changed_blocks": [1], # Only middle block changed
"new_signature": signature "new_signature": signature,
} }
# This would normally be called after should_process_file # This would normally be called after should_process_file