reformat applied

This commit is contained in:
vasilije 2025-07-01 14:26:56 +02:00
parent 1ebeeac61d
commit 07f2afa69d
39 changed files with 577 additions and 531 deletions

View file

@ -5,43 +5,49 @@ Revises: 1d0bb7fede17
Create Date: 2025-01-27 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from uuid import uuid4
# revision identifiers, used by Alembic.
revision = 'incremental_file_signatures'
down_revision = '1d0bb7fede17'
revision = "incremental_file_signatures"
down_revision = "1d0bb7fede17"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('file_signatures',
sa.Column('id', sa.UUID(), nullable=False, default=uuid4),
sa.Column('data_id', sa.UUID(), nullable=True),
sa.Column('file_path', sa.String(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True),
sa.Column('content_hash', sa.String(), nullable=True),
sa.Column('total_blocks', sa.Integer(), nullable=True),
sa.Column('block_size', sa.Integer(), nullable=True),
sa.Column('strong_len', sa.Integer(), nullable=True),
sa.Column('signature_data', sa.LargeBinary(), nullable=True),
sa.Column('blocks_info', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"file_signatures",
sa.Column("id", sa.UUID(), nullable=False, default=uuid4),
sa.Column("data_id", sa.UUID(), nullable=True),
sa.Column("file_path", sa.String(), nullable=True),
sa.Column("file_size", sa.Integer(), nullable=True),
sa.Column("content_hash", sa.String(), nullable=True),
sa.Column("total_blocks", sa.Integer(), nullable=True),
sa.Column("block_size", sa.Integer(), nullable=True),
sa.Column("strong_len", sa.Integer(), nullable=True),
sa.Column("signature_data", sa.LargeBinary(), nullable=True),
sa.Column("blocks_info", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_file_signatures_data_id"), "file_signatures", ["data_id"], unique=False
)
op.create_index(
op.f("ix_file_signatures_content_hash"), "file_signatures", ["content_hash"], unique=False
)
op.create_index(op.f('ix_file_signatures_data_id'), 'file_signatures', ['data_id'], unique=False)
op.create_index(op.f('ix_file_signatures_content_hash'), 'file_signatures', ['content_hash'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_file_signatures_content_hash'), table_name='file_signatures')
op.drop_index(op.f('ix_file_signatures_data_id'), table_name='file_signatures')
op.drop_table('file_signatures')
# ### end Alembic commands ###
op.drop_index(op.f("ix_file_signatures_content_hash"), table_name="file_signatures")
op.drop_index(op.f("ix_file_signatures_data_id"), table_name="file_signatures")
op.drop_table("file_signatures")
# ### end Alembic commands ###

View file

@ -556,7 +556,7 @@ def log_database_configuration():
elif relational_config.db_provider == "sqlite":
logger.info(f"SQLite path: {relational_config.db_path}")
logger.info(f"SQLite database: {relational_config.db_name}")
# Log vector database configuration
vector_config = get_vectordb_config()
logger.info(f"Vector database: {vector_config.vector_db_provider}")
@ -564,7 +564,7 @@ def log_database_configuration():
logger.info(f"Vector database path: {vector_config.vector_db_url}")
elif vector_config.vector_db_provider in ["qdrant", "weaviate", "pgvector"]:
logger.info(f"Vector database URL: {vector_config.vector_db_url}")
# Log graph database configuration
graph_config = get_graph_config()
logger.info(f"Graph database: {graph_config.graph_database_provider}")
@ -572,7 +572,7 @@ def log_database_configuration():
logger.info(f"Graph database path: {graph_config.graph_file_path}")
elif graph_config.graph_database_provider in ["neo4j", "falkordb"]:
logger.info(f"Graph database URL: {graph_config.graph_database_url}")
except Exception as e:
logger.warning(f"Could not retrieve database configuration: {str(e)}")
@ -591,7 +591,7 @@ async def main():
# Log database configurations
log_database_configuration()
logger.info(f"Starting MCP server with transport: {args.transport}")
if args.transport == "stdio":
await mcp.run_stdio_async()

View file

@ -18,6 +18,7 @@ litellm.set_verbose = False
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
logging.getLogger("litellm").setLevel(logging.CRITICAL)
class GenericAPIAdapter(LLMInterface):
"""
Adapter for Generic API LLM provider API.

View file

@ -10,26 +10,28 @@ class FileSignature(Base):
__tablename__ = "file_signatures"
id = Column(UUID, primary_key=True, default=uuid4)
# Reference to the original data entry
data_id = Column(UUID, index=True)
# File information
file_path = Column(String)
file_size = Column(Integer)
content_hash = Column(String, index=True) # Overall file hash for quick comparison
# Block information
total_blocks = Column(Integer)
block_size = Column(Integer)
strong_len = Column(Integer)
# Signature data (binary)
signature_data = Column(LargeBinary)
# Block details (JSON array of block info)
blocks_info = Column(JSON) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
blocks_info = Column(
JSON
) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@ -47,4 +49,4 @@ class FileSignature(Base):
"blocks_info": self.blocks_info,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
}

View file

@ -1,4 +1,4 @@
from .incremental_loader import IncrementalLoader
from .block_hash_service import BlockHashService
__all__ = ["IncrementalLoader", "BlockHashService"]
__all__ = ["IncrementalLoader", "BlockHashService"]

View file

@ -18,6 +18,7 @@ import tempfile
@dataclass
class BlockInfo:
"""Information about a file block"""
block_index: int
weak_checksum: int
strong_hash: str
@ -28,6 +29,7 @@ class BlockInfo:
@dataclass
class FileSignature:
"""File signature containing block information"""
file_path: str
file_size: int
total_blocks: int
@ -40,6 +42,7 @@ class FileSignature:
@dataclass
class FileDelta:
"""Delta information for changed blocks"""
changed_blocks: List[int] # Block indices that changed
delta_data: bytes
old_signature: FileSignature
@ -48,53 +51,51 @@ class FileDelta:
class BlockHashService:
"""Service for block-based file hashing using librsync algorithm"""
DEFAULT_BLOCK_SIZE = 1024 # 1KB blocks
DEFAULT_STRONG_LEN = 8 # 8 bytes for strong hash
def __init__(self, block_size: int = None, strong_len: int = None):
"""
Initialize the BlockHashService
Args:
block_size: Size of blocks in bytes (default: 1024)
strong_len: Length of strong hash in bytes (default: 8)
"""
self.block_size = block_size or self.DEFAULT_BLOCK_SIZE
self.strong_len = strong_len or self.DEFAULT_STRONG_LEN
def generate_signature(self, file_obj: BinaryIO, file_path: str = None) -> FileSignature:
"""
Generate a signature for a file using librsync algorithm
Args:
file_obj: File object to generate signature for
file_path: Optional file path for metadata
Returns:
FileSignature object containing block information
"""
file_obj.seek(0)
file_data = file_obj.read()
file_size = len(file_data)
# Calculate optimal signature parameters
magic, block_len, strong_len = get_signature_args(
file_size,
block_len=self.block_size,
strong_len=self.strong_len
file_size, block_len=self.block_size, strong_len=self.strong_len
)
# Generate signature using librsync
file_io = BytesIO(file_data)
sig_io = BytesIO()
signature(file_io, sig_io, strong_len, magic, block_len)
signature_data = sig_io.getvalue()
# Parse signature to extract block information
blocks = self._parse_signature(signature_data, file_data, block_len)
return FileSignature(
file_path=file_path or "",
file_size=file_size,
@ -102,52 +103,56 @@ class BlockHashService:
block_size=block_len,
strong_len=strong_len,
blocks=blocks,
signature_data=signature_data
signature_data=signature_data,
)
def _parse_signature(self, signature_data: bytes, file_data: bytes, block_size: int) -> List[BlockInfo]:
def _parse_signature(
self, signature_data: bytes, file_data: bytes, block_size: int
) -> List[BlockInfo]:
"""
Parse signature data to extract block information
Args:
signature_data: Raw signature data from librsync
file_data: Original file data
block_size: Size of blocks
Returns:
List of BlockInfo objects
"""
blocks = []
total_blocks = (len(file_data) + block_size - 1) // block_size
for i in range(total_blocks):
start_offset = i * block_size
end_offset = min(start_offset + block_size, len(file_data))
block_data = file_data[start_offset:end_offset]
# Calculate weak checksum (simple Adler-32 variant)
weak_checksum = self._calculate_weak_checksum(block_data)
# Calculate strong hash (MD5)
strong_hash = hashlib.md5(block_data).hexdigest()
blocks.append(BlockInfo(
block_index=i,
weak_checksum=weak_checksum,
strong_hash=strong_hash,
block_size=len(block_data),
file_offset=start_offset
))
blocks.append(
BlockInfo(
block_index=i,
weak_checksum=weak_checksum,
strong_hash=strong_hash,
block_size=len(block_data),
file_offset=start_offset,
)
)
return blocks
def _calculate_weak_checksum(self, data: bytes) -> int:
"""
Calculate a weak checksum similar to Adler-32
Args:
data: Block data
Returns:
Weak checksum value
"""
@ -157,111 +162,116 @@ class BlockHashService:
a = (a + byte) % 65521
b = (b + a) % 65521
return (b << 16) | a
def compare_signatures(self, old_sig: FileSignature, new_sig: FileSignature) -> List[int]:
"""
Compare two signatures to find changed blocks
Args:
old_sig: Previous file signature
new_sig: New file signature
Returns:
List of block indices that have changed
"""
changed_blocks = []
# Create lookup tables for efficient comparison
old_blocks = {block.block_index: block for block in old_sig.blocks}
new_blocks = {block.block_index: block for block in new_sig.blocks}
# Find changed, added, or removed blocks
all_indices = set(old_blocks.keys()) | set(new_blocks.keys())
for block_idx in all_indices:
old_block = old_blocks.get(block_idx)
new_block = new_blocks.get(block_idx)
if old_block is None or new_block is None:
# Block was added or removed
changed_blocks.append(block_idx)
elif (old_block.weak_checksum != new_block.weak_checksum or
old_block.strong_hash != new_block.strong_hash):
elif (
old_block.weak_checksum != new_block.weak_checksum
or old_block.strong_hash != new_block.strong_hash
):
# Block content changed
changed_blocks.append(block_idx)
return sorted(changed_blocks)
def generate_delta(self, old_file: BinaryIO, new_file: BinaryIO,
old_signature: FileSignature = None) -> FileDelta:
def generate_delta(
self, old_file: BinaryIO, new_file: BinaryIO, old_signature: FileSignature = None
) -> FileDelta:
"""
Generate a delta between two file versions
Args:
old_file: Previous version of the file
new_file: New version of the file
old_signature: Optional pre-computed signature of old file
Returns:
FileDelta object containing change information
"""
# Generate signatures if not provided
if old_signature is None:
old_signature = self.generate_signature(old_file)
new_signature = self.generate_signature(new_file)
# Generate delta using librsync
new_file.seek(0)
old_sig_io = BytesIO(old_signature.signature_data)
delta_io = BytesIO()
delta(new_file, old_sig_io, delta_io)
delta_data = delta_io.getvalue()
# Find changed blocks
changed_blocks = self.compare_signatures(old_signature, new_signature)
return FileDelta(
changed_blocks=changed_blocks,
delta_data=delta_data,
old_signature=old_signature,
new_signature=new_signature
new_signature=new_signature,
)
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
"""
Apply a delta to reconstruct the new file
Args:
old_file: Original file
delta_obj: Delta information
Returns:
BytesIO object containing the reconstructed file
"""
old_file.seek(0)
delta_io = BytesIO(delta_obj.delta_data)
result_io = BytesIO()
patch(old_file, delta_io, result_io)
result_io.seek(0)
return result_io
def calculate_block_changes(self, old_sig: FileSignature, new_sig: FileSignature) -> Dict[str, Any]:
def calculate_block_changes(
self, old_sig: FileSignature, new_sig: FileSignature
) -> Dict[str, Any]:
"""
Calculate detailed statistics about block changes
Args:
old_sig: Previous file signature
new_sig: New file signature
Returns:
Dictionary with change statistics
"""
changed_blocks = self.compare_signatures(old_sig, new_sig)
return {
"total_old_blocks": len(old_sig.blocks),
"total_new_blocks": len(new_sig.blocks),
@ -271,4 +281,4 @@ class BlockHashService:
"compression_ratio": 1.0 - (len(changed_blocks) / max(len(old_sig.blocks), 1)),
"old_file_size": old_sig.file_size,
"new_file_size": new_sig.file_size,
}
}

View file

@ -26,99 +26,107 @@ class IncrementalLoader:
"""
Incremental file loader using rsync algorithm for efficient updates
"""
def __init__(self, block_size: int = 1024, strong_len: int = 8):
"""
Initialize the incremental loader
Args:
block_size: Size of blocks in bytes for rsync algorithm
strong_len: Length of strong hash in bytes
"""
self.block_service = BlockHashService(block_size, strong_len)
async def should_process_file(self, file_obj: BinaryIO, data_id: str) -> Tuple[bool, Optional[Dict]]:
async def should_process_file(
self, file_obj: BinaryIO, data_id: str
) -> Tuple[bool, Optional[Dict]]:
"""
Determine if a file should be processed based on incremental changes
Args:
file_obj: File object to check
data_id: Data ID for the file
Returns:
Tuple of (should_process, change_info)
- should_process: True if file needs processing
- change_info: Dictionary with change details if applicable
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
# Check if we have an existing signature for this file
existing_signature = await self._get_existing_signature(session, data_id)
if existing_signature is None:
# First time seeing this file, needs full processing
return True, {"type": "new_file", "full_processing": True}
# Generate signature for current file version
current_signature = self.block_service.generate_signature(file_obj)
# Quick check: if overall content hash is the same, no changes
file_obj.seek(0)
current_content_hash = get_file_content_hash(file_obj)
if current_content_hash == existing_signature.content_hash:
return False, {"type": "no_changes", "full_processing": False}
# Convert database signature to service signature for comparison
service_old_sig = self._db_signature_to_service(existing_signature)
# Compare signatures to find changed blocks
changed_blocks = self.block_service.compare_signatures(service_old_sig, current_signature)
changed_blocks = self.block_service.compare_signatures(
service_old_sig, current_signature
)
if not changed_blocks:
# Signatures match, no processing needed
return False, {"type": "no_changes", "full_processing": False}
# Calculate change statistics
change_stats = self.block_service.calculate_block_changes(service_old_sig, current_signature)
change_stats = self.block_service.calculate_block_changes(
service_old_sig, current_signature
)
change_info = {
"type": "incremental_changes",
"full_processing": len(changed_blocks) > (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
"full_processing": len(changed_blocks)
> (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
"changed_blocks": changed_blocks,
"stats": change_stats,
"new_signature": current_signature,
"old_signature": service_old_sig,
}
return True, change_info
async def process_incremental_changes(self, file_obj: BinaryIO, data_id: str,
change_info: Dict) -> List[Dict]:
async def process_incremental_changes(
self, file_obj: BinaryIO, data_id: str, change_info: Dict
) -> List[Dict]:
"""
Process only the changed blocks of a file
Args:
file_obj: File object to process
data_id: Data ID for the file
change_info: Change information from should_process_file
Returns:
List of block data that needs reprocessing
"""
if change_info["type"] != "incremental_changes":
raise ValueError("Invalid change_info type for incremental processing")
file_obj.seek(0)
file_data = file_obj.read()
changed_blocks = change_info["changed_blocks"]
new_signature = change_info["new_signature"]
# Extract data for changed blocks
changed_block_data = []
for block_idx in changed_blocks:
# Find the block info
block_info = None
@ -126,49 +134,51 @@ class IncrementalLoader:
if block.block_index == block_idx:
block_info = block
break
if block_info is None:
continue
# Extract block data
start_offset = block_info.file_offset
end_offset = start_offset + block_info.block_size
block_data = file_data[start_offset:end_offset]
changed_block_data.append({
"block_index": block_idx,
"block_data": block_data,
"block_info": block_info,
"file_offset": start_offset,
"block_size": len(block_data),
})
changed_block_data.append(
{
"block_index": block_idx,
"block_data": block_data,
"block_info": block_info,
"file_offset": start_offset,
"block_size": len(block_data),
}
)
return changed_block_data
async def save_file_signature(self, file_obj: BinaryIO, data_id: str) -> None:
"""
Save or update the file signature in the database
Args:
file_obj: File object
data_id: Data ID for the file
"""
# Generate signature
signature = self.block_service.generate_signature(file_obj, str(data_id))
# Calculate content hash
file_obj.seek(0)
content_hash = get_file_content_hash(file_obj)
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
# Check if signature already exists
existing = await session.execute(
select(FileSignature).filter(FileSignature.data_id == data_id)
)
existing_signature = existing.scalar_one_or_none()
# Prepare block info for JSON storage
blocks_info = [
{
@ -180,7 +190,7 @@ class IncrementalLoader:
}
for block in signature.blocks
]
if existing_signature:
# Update existing signature
existing_signature.file_path = signature.file_path
@ -191,7 +201,7 @@ class IncrementalLoader:
existing_signature.strong_len = signature.strong_len
existing_signature.signature_data = signature.signature_data
existing_signature.blocks_info = blocks_info
await session.merge(existing_signature)
else:
# Create new signature
@ -207,17 +217,19 @@ class IncrementalLoader:
blocks_info=blocks_info,
)
session.add(new_signature)
await session.commit()
async def _get_existing_signature(self, session: AsyncSession, data_id: str) -> Optional[FileSignature]:
async def _get_existing_signature(
self, session: AsyncSession, data_id: str
) -> Optional[FileSignature]:
"""
Get existing file signature from database
Args:
session: Database session
data_id: Data ID to search for
Returns:
FileSignature object or None if not found
"""
@ -225,19 +237,19 @@ class IncrementalLoader:
select(FileSignature).filter(FileSignature.data_id == data_id)
)
return result.scalar_one_or_none()
def _db_signature_to_service(self, db_signature: FileSignature) -> ServiceFileSignature:
"""
Convert database FileSignature to service FileSignature
Args:
db_signature: Database signature object
Returns:
Service FileSignature object
"""
from .block_hash_service import BlockInfo
# Convert blocks info
blocks = [
BlockInfo(
@ -249,7 +261,7 @@ class IncrementalLoader:
)
for block in db_signature.blocks_info
]
return ServiceFileSignature(
file_path=db_signature.file_path,
file_size=db_signature.file_size,
@ -259,26 +271,26 @@ class IncrementalLoader:
blocks=blocks,
signature_data=db_signature.signature_data,
)
async def cleanup_orphaned_signatures(self) -> int:
"""
Clean up file signatures that no longer have corresponding data entries
Returns:
Number of signatures removed
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
# Find signatures without corresponding data entries
orphaned_query = """
DELETE FROM file_signatures
WHERE data_id NOT IN (SELECT id FROM data)
"""
result = await session.execute(orphaned_query)
removed_count = result.rowcount
await session.commit()
return removed_count
return removed_count

View file

@ -109,13 +109,15 @@ async def ingest_data(
data_id = ingestion.identify(classified_data, user)
file_metadata = classified_data.get_metadata()
# Initialize incremental loader for this file
incremental_loader = IncrementalLoader()
# Check if file needs incremental processing
should_process, change_info = await incremental_loader.should_process_file(file, data_id)
should_process, change_info = await incremental_loader.should_process_file(
file, data_id
)
# Save updated file signature regardless of whether processing is needed
await incremental_loader.save_file_signature(file, data_id)
@ -150,12 +152,12 @@ async def ingest_data(
ext_metadata = get_external_metadata_dict(data_item)
if node_set:
ext_metadata["node_set"] = node_set
# Add incremental processing metadata
ext_metadata["incremental_processing"] = {
"should_process": should_process,
"change_info": change_info,
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat()))
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())),
}
if data_point is not None:

View file

@ -51,12 +51,12 @@ def test_AudioDocument(mock_engine):
GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64),
):
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.chunk_size
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -34,12 +34,12 @@ def test_ImageDocument(mock_engine):
GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64),
):
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.chunk_size
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -36,12 +36,12 @@ def test_PdfDocument(mock_engine):
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
):
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.chunk_size
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -49,12 +49,12 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
GROUND_TRUTH[input_file],
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
):
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.chunk_size
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -79,32 +79,32 @@ def test_UnstructuredDocument(mock_engine):
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
# Test DOCX
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert "sentence_end" == paragraph_data.cut_type, (
f" sentence_end != {paragraph_data.cut_type = }"
)
assert (
"sentence_end" == paragraph_data.cut_type
), f" sentence_end != {paragraph_data.cut_type = }"
# TEST CSV
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
f"Read text doesn't match expected text: {paragraph_data.text}"
)
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
assert (
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
), f"Read text doesn't match expected text: {paragraph_data.text}"
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
# Test XLSX
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"

View file

@ -12,9 +12,9 @@ async def check_graph_metrics_consistency_across_adapters(include_optional=False
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
for key, neo4j_value in neo4j_metrics.items():
assert networkx_metrics[key] == neo4j_value, (
f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
)
assert (
networkx_metrics[key] == neo4j_value
), f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
if __name__ == "__main__":

View file

@ -71,6 +71,6 @@ async def assert_metrics(provider, include_optional=True):
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
for key, ground_truth_value in ground_truth_metrics.items():
assert metrics[key] == ground_truth_value, (
f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
)
assert (
metrics[key] == ground_truth_value
), f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"

View file

@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), (
f"Data location still exists after deletion: {data.raw_data_location}"
)
assert not os.path.exists(
data.raw_data_location
), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session:
# Get data entry from database based on file path
data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one()
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), (
f"Data location doesn't exists: {data.raw_data_location}"
)
assert os.path.exists(
data.raw_data_location
), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1):
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, (
f"Number of expected documents doesn't match {len(document_ids)} != 1"
)
assert (
len(document_ids) == 1
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, (
f"Number of expected documents doesn't match {len(document_ids)} != 2"
)
assert (
len(document_ids) == 2
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main():

View file

@ -30,9 +30,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."
assert result[0]["name"] == "Natural_language_processing_copy", (
"Result name does not match expected value."
)
assert (
result[0]["name"] == "Natural_language_processing_copy"
), "Result name does not match expected value."
result = await relational_engine.get_all_data_from_table("datasets")
assert len(result) == 2, "Unexpected number of datasets found."
@ -61,9 +61,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
"Content hash is not a part of file name."
)
assert (
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
), "Content hash is not a part of file name."
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -92,9 +92,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(get_relational_engine().db_path), (
"SQLite relational database is not empty"
)
assert not os.path.exists(
get_relational_engine().db_path
), "SQLite relational database is not empty"
from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -103,13 +103,13 @@ async def main():
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert (
isinstance(context_nonempty, str) and context_nonempty != ""
), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
assert context_empty == "", (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)
assert (
context_empty == ""
), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -107,13 +107,13 @@ async def main():
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert (
isinstance(context_nonempty, str) and context_nonempty != ""
), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
assert context_empty == "", (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)
assert (
context_empty == ""
), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -23,28 +23,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), (
f"Data location still exists after deletion: {data.raw_data_location}"
)
assert not os.path.exists(
data.raw_data_location
), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session:
# Get data entry from database based on file path
data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one()
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), (
f"Data location doesn't exists: {data.raw_data_location}"
)
assert os.path.exists(
data.raw_data_location
), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1):
@ -53,16 +53,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, (
f"Number of expected documents doesn't match {len(document_ids)} != 1"
)
assert (
len(document_ids) == 1
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, (
f"Number of expected documents doesn't match {len(document_ids)} != 2"
)
assert (
len(document_ids) == 2
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main():

View file

@ -112,9 +112,9 @@ async def relational_db_migration():
else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
assert len(distinct_node_names) == 12, (
f"Expected 12 distinct node references, found {len(distinct_node_names)}"
)
assert (
len(distinct_node_names) == 12
), f"Expected 12 distinct node references, found {len(distinct_node_names)}"
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
expected_edges = {

View file

@ -29,54 +29,54 @@ async def main():
logging.info(edge_type_counts)
# Assert there is exactly one PdfDocument.
assert type_counts.get("PdfDocument", 0) == 1, (
f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
)
assert (
type_counts.get("PdfDocument", 0) == 1
), f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
# Assert there is exactly one TextDocument.
assert type_counts.get("TextDocument", 0) == 1, (
f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
)
assert (
type_counts.get("TextDocument", 0) == 1
), f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
# Assert there are at least two DocumentChunk nodes.
assert type_counts.get("DocumentChunk", 0) >= 2, (
f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
)
assert (
type_counts.get("DocumentChunk", 0) >= 2
), f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
# Assert there is at least two TextSummary.
assert type_counts.get("TextSummary", 0) >= 2, (
f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
)
assert (
type_counts.get("TextSummary", 0) >= 2
), f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
# Assert there is at least one Entity.
assert type_counts.get("Entity", 0) > 0, (
f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
)
assert (
type_counts.get("Entity", 0) > 0
), f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
# Assert there is at least one EntityType.
assert type_counts.get("EntityType", 0) > 0, (
f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
)
assert (
type_counts.get("EntityType", 0) > 0
), f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
# Assert that there are at least two 'is_part_of' edges.
assert edge_type_counts.get("is_part_of", 0) >= 2, (
f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
)
assert (
edge_type_counts.get("is_part_of", 0) >= 2
), f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
# Assert that there are at least two 'made_from' edges.
assert edge_type_counts.get("made_from", 0) >= 2, (
f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
)
assert (
edge_type_counts.get("made_from", 0) >= 2
), f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
# Assert that there is at least one 'is_a' edge.
assert edge_type_counts.get("is_a", 0) >= 1, (
f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
)
assert (
edge_type_counts.get("is_a", 0) >= 1
), f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
# Assert that there is at least one 'contains' edge.
assert edge_type_counts.get("contains", 0) >= 1, (
f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
)
assert (
edge_type_counts.get("contains", 0) >= 1
), f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
if __name__ == "__main__":

View file

@ -66,9 +66,9 @@ async def main():
assert isinstance(context, str), f"{name}: Context should be a string"
assert context.strip(), f"{name}: Context should not be empty"
lower = context.lower()
assert "germany" in lower or "netherlands" in lower, (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
)
assert (
"germany" in lower or "netherlands" in lower
), f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
triplets_gk = await GraphCompletionRetriever().get_triplets(
query="Next to which country is Germany located?"
@ -96,18 +96,18 @@ async def main():
distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), (
f"{name}: vector_distance should be float, got {type(distance)}"
)
assert 0 <= distance <= 1, (
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert 0 <= node1_distance <= 1, (
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert 0 <= node2_distance <= 1, (
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
)
assert isinstance(
distance, float
), f"{name}: vector_distance should be float, got {type(distance)}"
assert (
0 <= distance <= 1
), f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
assert (
0 <= node1_distance <= 1
), f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
assert (
0 <= node2_distance <= 1
), f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
@ -137,9 +137,9 @@ async def main():
text = completion[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
assert (
"netherlands" in text.lower()
), f"{name}: expected 'netherlands' in result, got: {text!r}"
if __name__ == "__main__":

View file

@ -24,12 +24,12 @@ async def test_answer_generation():
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
assert len(answers) == len(qa_pairs)
assert answers[0]["question"] == qa_pairs[0]["question"], (
"AnswerGeneratorExecutor is passing the question incorrectly"
)
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
"AnswerGeneratorExecutor is passing the golden answer incorrectly"
)
assert answers[0]["answer"] == "Mocked answer", (
"AnswerGeneratorExecutor is passing the generated answer incorrectly"
)
assert (
answers[0]["question"] == qa_pairs[0]["question"]
), "AnswerGeneratorExecutor is passing the question incorrectly"
assert (
answers[0]["golden_answer"] == qa_pairs[0]["answer"]
), "AnswerGeneratorExecutor is passing the golden answer incorrectly"
assert (
answers[0]["answer"] == "Mocked answer"
), "AnswerGeneratorExecutor is passing the generated answer incorrectly"

View file

@ -44,9 +44,9 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
corpus_list, qa_pairs = result
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
assert isinstance(qa_pairs, list), (
f"{AdapterClass.__name__} question_answer_pairs is not a list."
)
assert isinstance(
qa_pairs, list
), f"{AdapterClass.__name__} question_answer_pairs is not a list."
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
@ -71,9 +71,9 @@ def test_adapter_returns_some_content(AdapterClass):
# We don't know how large the dataset is, but we expect at least 1 item
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
assert len(qa_pairs) <= limit, (
f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
)
assert (
len(qa_pairs) <= limit
), f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
for item in qa_pairs:
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."

View file

@ -12,9 +12,9 @@ def test_corpus_builder_load_corpus(benchmark):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
)
assert (
len(questions) <= 2
), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@pytest.mark.asyncio
@ -24,6 +24,6 @@ async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit)
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
)
assert (
len(questions) <= 2
), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"

View file

@ -52,14 +52,14 @@ def test_metrics(metrics, actual, expected, expected_exact_score, expected_f1_ra
test_case = MockTestCase(actual, expected)
exact_match_score = metrics["exact_match"].measure(test_case)
assert exact_match_score == expected_exact_score, (
f"Exact match failed for '{actual}' vs '{expected}'"
)
assert (
exact_match_score == expected_exact_score
), f"Exact match failed for '{actual}' vs '{expected}'"
f1_score = metrics["f1"].measure(test_case)
assert expected_f1_range[0] <= f1_score <= expected_f1_range[1], (
f"F1 score failed for '{actual}' vs '{expected}'"
)
assert (
expected_f1_range[0] <= f1_score <= expected_f1_range[1]
), f"F1 score failed for '{actual}' vs '{expected}'"
class TestBootstrapCI(unittest.TestCase):

View file

@ -157,15 +157,15 @@ def test_rate_limit_60_per_minute():
if len(failures) > 0:
first_failure_idx = int(failures[0].split()[1])
print(f"First failure occurred at request index: {first_failure_idx}")
assert 58 <= first_failure_idx <= 62, (
f"Expected first failure around request #60, got #{first_failure_idx}"
)
assert (
58 <= first_failure_idx <= 62
), f"Expected first failure around request #60, got #{first_failure_idx}"
# Calculate requests per minute
rate_per_minute = len(successes)
print(f"Rate: {rate_per_minute} requests per minute")
# Verify the rate is close to our target of 60 requests per minute
assert 58 <= rate_per_minute <= 62, (
f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
)
assert (
58 <= rate_per_minute <= 62
), f"Expected rate of ~60 requests per minute, got {rate_per_minute}"

View file

@ -110,9 +110,9 @@ def test_sync_retry():
print(f"Number of attempts: {test_function_sync.counter}")
# The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_sync.counter == 3, (
f"Expected 3 attempts, got {test_function_sync.counter}"
)
assert (
test_function_sync.counter == 3
), f"Expected 3 attempts, got {test_function_sync.counter}"
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Synchronous retry mechanism is working correctly")
@ -143,9 +143,9 @@ async def test_async_retry():
print(f"Number of attempts: {test_function_async.counter}")
# The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_async.counter == 3, (
f"Expected 3 attempts, got {test_function_async.counter}"
)
assert (
test_function_async.counter == 3
), f"Expected 3 attempts, got {test_function_async.counter}"
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Asynchronous retry mechanism is working correctly")

View file

@ -57,9 +57,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert all(
isinstance(item, str) and item.strip() for item in answer
), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self):
@ -136,9 +136,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert all(
isinstance(item, str) and item.strip() for item in answer
), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self):
@ -167,9 +167,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert all(
isinstance(item, str) and item.strip() for item in answer
), "Answer must contain only non-empty strings"
if __name__ == "__main__":

View file

@ -55,9 +55,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert all(
isinstance(item, str) and item.strip() for item in answer
), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self):
@ -134,9 +134,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert all(
isinstance(item, str) and item.strip() for item in answer
), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self):
@ -165,9 +165,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert all(
isinstance(item, str) and item.strip() for item in answer
), "Answer must contain only non-empty strings"
if __name__ == "__main__":

View file

@ -24,9 +24,9 @@ max_chunk_size_vals = [512, 1024, 4096]
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize(
@ -54,9 +54,9 @@ def test_paragraph_chunk_length(input_text, max_chunk_size, batch_paragraphs):
)
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
assert np.all(chunk_lengths <= max_chunk_size), (
f"{max_chunk_size = }: {larger_chunks} are too large"
)
assert np.all(
chunk_lengths <= max_chunk_size
), f"{max_chunk_size = }: {larger_chunks} are too large"
@pytest.mark.parametrize(
@ -76,6 +76,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, max_chunk_size, batch_pa
batch_paragraphs=batch_paragraphs,
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
f"{chunk_indices = } are not monotonically increasing"
)
assert np.all(
chunk_indices == np.arange(len(chunk_indices))
), f"{chunk_indices = } are not monotonically increasing"

View file

@ -71,9 +71,9 @@ def run_chunking_test(test_text, expected_chunks, mock_engine):
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
for key in ["text", "chunk_size", "cut_type"]:
assert chunk[key] == expected_chunks_item[key], (
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
)
assert (
chunk[key] == expected_chunks_item[key]
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
def test_chunking_whole_text():

View file

@ -17,9 +17,9 @@ maximum_length_vals = [None, 16, 64]
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
chunks = chunk_by_sentence(input_text, maximum_length)
reconstructed_text = "".join([chunk[1] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize(
@ -40,9 +40,9 @@ def test_paragraph_chunk_length(input_text, maximum_length):
)
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
assert np.all(chunk_lengths <= maximum_length), (
f"{maximum_length = }: {larger_chunks} are too large"
)
assert np.all(
chunk_lengths <= maximum_length
), f"{maximum_length = }: {larger_chunks} are too large"
@pytest.mark.parametrize(

View file

@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS, INPUT_TE
def test_chunk_by_word_isomorphism(input_text):
chunks = chunk_by_word(input_text)
reconstructed_text = "".join([chunk[0] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize(

View file

@ -18,14 +18,14 @@ async def demonstrate_incremental_loading():
Demonstrate incremental file loading by creating a file, modifying it,
and showing how only changed blocks are detected.
"""
print("🚀 Cognee Incremental File Loading Demo")
print("=" * 50)
# Initialize the incremental loader
incremental_loader = IncrementalLoader(block_size=512) # 512 byte blocks for demo
block_service = BlockHashService(block_size=512)
# Create initial file content
initial_content = b"""
This is the initial content of our test file.
@ -40,7 +40,7 @@ Block 5: Excepteur sint occaecat cupidatat non proident, sunt in culpa.
This is the end of the initial content.
"""
# Create modified content (change Block 2 and add Block 6)
modified_content = b"""
This is the initial content of our test file.
@ -56,64 +56,70 @@ Block 6: NEW BLOCK - This is additional content that was added.
This is the end of the modified content.
"""
print("1. Creating signatures for initial and modified versions...")
# Generate signatures
initial_file = BytesIO(initial_content)
modified_file = BytesIO(modified_content)
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks")
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks")
print(
f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
)
print(
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
)
# Compare signatures to find changes
print("\n2. Comparing signatures to detect changes...")
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
print(f" Changed blocks: {changed_blocks}")
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}")
print(
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
)
# Generate delta
print("\n3. Generating delta for changed content...")
initial_file.seek(0)
modified_file.seek(0)
delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
print(f" Delta size: {len(delta.delta_data)} bytes")
print(f" Changed blocks in delta: {delta.changed_blocks}")
# Demonstrate reconstruction
print("\n4. Reconstructing file from delta...")
initial_file.seek(0)
reconstructed = block_service.apply_delta(initial_file, delta)
reconstructed_content = reconstructed.read()
print(f" Reconstruction successful: {reconstructed_content == modified_content}")
print(f" Reconstructed size: {len(reconstructed_content)} bytes")
# Show block details
print("\n5. Block-by-block analysis:")
print(" Block | Status | Strong Hash (first 8 chars)")
print(" ------|----------|---------------------------")
old_blocks = {b.block_index: b for b in initial_signature.blocks}
new_blocks = {b.block_index: b for b in modified_signature.blocks}
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
for idx in all_indices:
old_block = old_blocks.get(idx)
new_block = new_blocks.get(idx)
if old_block is None:
status = "ADDED"
hash_display = new_block.strong_hash[:8] if new_block else ""
@ -126,9 +132,9 @@ This is the end of the modified content.
else:
status = "MODIFIED"
hash_display = f"{old_block.strong_hash[:8]}{new_block.strong_hash[:8]}"
print(f" {idx:5d} | {status:8s} | {hash_display}")
print("\n✅ Incremental loading demo completed!")
print("\nThis demonstrates how Cognee can efficiently process only the changed")
print("parts of files, significantly reducing processing time for large files")
@ -139,35 +145,35 @@ async def demonstrate_with_cognee():
"""
Demonstrate integration with Cognee's add functionality
"""
print("\n" + "=" * 50)
print("🔧 Integration with Cognee Add Functionality")
print("=" * 50)
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("Initial content for Cognee processing.")
temp_file_path = f.name
try:
print(f"1. Adding initial file: {temp_file_path}")
# Add file to Cognee
await cognee.add(temp_file_path)
print(" ✅ File added successfully")
# Modify the file
with open(temp_file_path, 'w') as f:
with open(temp_file_path, "w") as f:
f.write("Modified content for Cognee processing with additional text.")
print("2. Adding modified version of the same file...")
# Add modified file - this should trigger incremental processing
await cognee.add(temp_file_path)
print(" ✅ Modified file processed with incremental loading")
finally:
# Clean up
if os.path.exists(temp_file_path):
@ -176,11 +182,11 @@ async def demonstrate_with_cognee():
if __name__ == "__main__":
import asyncio
print("Starting Cognee Incremental Loading Demo...")
# Run the demonstration
asyncio.run(demonstrate_incremental_loading())
# Uncomment the line below to test with actual Cognee integration
# asyncio.run(demonstrate_with_cognee())
# asyncio.run(demonstrate_with_cognee())

View file

@ -4,7 +4,8 @@ Simple test for incremental loading functionality
"""
import sys
sys.path.insert(0, '.')
sys.path.insert(0, ".")
from io import BytesIO
from cognee.modules.ingestion.incremental import BlockHashService
@ -14,13 +15,13 @@ def test_incremental_loading():
"""
Simple test of the incremental loading functionality
"""
print("🚀 Cognee Incremental File Loading Test")
print("=" * 50)
# Initialize the block service
block_service = BlockHashService(block_size=64) # Small blocks for demo
# Create initial file content
initial_content = b"""This is the initial content.
Line 1: Lorem ipsum dolor sit amet
@ -28,7 +29,7 @@ Line 2: Consectetur adipiscing elit
Line 3: Sed do eiusmod tempor
Line 4: Incididunt ut labore et dolore
Line 5: End of initial content"""
# Create modified content (change Line 2 and add Line 6)
modified_content = b"""This is the initial content.
Line 1: Lorem ipsum dolor sit amet
@ -37,64 +38,70 @@ Line 3: Sed do eiusmod tempor
Line 4: Incididunt ut labore et dolore
Line 5: End of initial content
Line 6: NEW - This is additional content"""
print("1. Creating signatures for initial and modified versions...")
# Generate signatures
initial_file = BytesIO(initial_content)
modified_file = BytesIO(modified_content)
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks")
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks")
print(
f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
)
print(
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
)
# Compare signatures to find changes
print("\n2. Comparing signatures to detect changes...")
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
print(f" Changed blocks: {changed_blocks}")
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}")
print(
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
)
# Generate delta
print("\n3. Generating delta for changed content...")
initial_file.seek(0)
modified_file.seek(0)
delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
print(f" Delta size: {len(delta.delta_data)} bytes")
print(f" Changed blocks in delta: {delta.changed_blocks}")
# Demonstrate reconstruction
print("\n4. Reconstructing file from delta...")
initial_file.seek(0)
reconstructed = block_service.apply_delta(initial_file, delta)
reconstructed_content = reconstructed.read()
print(f" Reconstruction successful: {reconstructed_content == modified_content}")
print(f" Reconstructed size: {len(reconstructed_content)} bytes")
# Show block details
print("\n5. Block-by-block analysis:")
print(" Block | Status | Strong Hash (first 8 chars)")
print(" ------|----------|---------------------------")
old_blocks = {b.block_index: b for b in initial_signature.blocks}
new_blocks = {b.block_index: b for b in modified_signature.blocks}
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
for idx in all_indices:
old_block = old_blocks.get(idx)
new_block = new_blocks.get(idx)
if old_block is None:
status = "ADDED"
hash_display = new_block.strong_hash[:8] if new_block else ""
@ -107,14 +114,14 @@ Line 6: NEW - This is additional content"""
else:
status = "MODIFIED"
hash_display = f"{old_block.strong_hash[:8]}{new_block.strong_hash[:8]}"
print(f" {idx:5d} | {status:8s} | {hash_display}")
print("\n✅ Incremental loading test completed!")
print("\nThis demonstrates how Cognee can efficiently process only the changed")
print("parts of files, significantly reducing processing time for large files")
print("with small modifications.")
return True
@ -124,4 +131,4 @@ if __name__ == "__main__":
print("\n🎉 Test passed successfully!")
else:
print("\n❌ Test failed!")
sys.exit(1)
sys.exit(1)

View file

@ -9,96 +9,96 @@ from cognee.modules.ingestion.incremental import BlockHashService, IncrementalLo
class TestBlockHashService:
"""Test the core block hashing service"""
def test_signature_generation(self):
    """Generating a signature records path, size, block size, blocks, and raw data."""
    payload = b"Hello, this is a test file for block hashing!"
    hasher = BlockHashService(block_size=10)
    sig = hasher.generate_signature(BytesIO(payload), "test.txt")
    # Metadata must mirror the input file exactly.
    assert sig.file_path == "test.txt"
    assert sig.file_size == len(payload)
    assert sig.block_size == 10
    # A non-empty file always yields at least one block plus raw signature bytes.
    assert len(sig.blocks) > 0
    assert sig.signature_data is not None
def test_change_detection(self):
    """A single-word edit shows up as changed blocks, but not as every block."""
    hasher = BlockHashService(block_size=10)
    before = b"Hello, world! This is the original content."
    after = b"Hello, world! This is the MODIFIED content."
    sig_before = hasher.generate_signature(BytesIO(before))
    sig_after = hasher.generate_signature(BytesIO(after))
    diff = hasher.compare_signatures(sig_before, sig_after)
    # The edit must be visible as at least one differing block...
    assert len(diff) > 0
    # ...yet stay localized: untouched blocks must not be reported.
    assert len(diff) < len(sig_before.blocks)
def test_no_changes(self):
    """Byte-identical inputs yield signatures with zero differing blocks."""
    hasher = BlockHashService(block_size=10)
    payload = b"This content will not change at all!"
    # Two independent reads of the same bytes must hash identically.
    first_sig = hasher.generate_signature(BytesIO(payload))
    second_sig = hasher.generate_signature(BytesIO(payload))
    assert len(hasher.compare_signatures(first_sig, second_sig)) == 0
def test_delta_generation(self):
    """A delta built from two versions reconstructs the modified file exactly."""
    hasher = BlockHashService(block_size=8)
    base = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    edited = b"ABCDEFGHXXXXXXXXXXXXXXWXYZ"  # middle bytes replaced
    base_file = BytesIO(base)
    delta = hasher.generate_delta(base_file, BytesIO(edited))
    # The delta must actually carry the change.
    assert len(delta.changed_blocks) > 0
    assert delta.delta_data is not None
    # Rewind the base stream, then round-trip through the delta.
    base_file.seek(0)
    rebuilt = hasher.apply_delta(base_file, delta)
    assert rebuilt.read() == edited
def test_block_statistics(self):
    """calculate_block_changes reports block totals and the changed-block count."""
    hasher = BlockHashService(block_size=5)
    # Two 5-byte blocks each; only the trailing block differs between versions.
    old_sig = hasher.generate_signature(BytesIO(b"ABCDEFGHIJ"))
    new_sig = hasher.generate_signature(BytesIO(b"ABCDEFXXXX"))
    stats = hasher.calculate_block_changes(old_sig, new_sig)
    assert stats["total_old_blocks"] == 2
    assert stats["total_new_blocks"] == 2
    # Exactly one block (the second) was rewritten.
    assert stats["changed_blocks"] == 1
@ -107,36 +107,36 @@ class TestBlockHashService:
class TestIncrementalLoader:
"""Test the incremental loader integration"""
@pytest.mark.asyncio
async def test_should_process_new_file(self):
    """Placeholder: a never-seen file (no stored signature) should be processed.

    The real decision path needs a database holding prior signatures, which
    this unit-test module does not set up. Until an integration fixture
    exists, we only smoke-test that the loader can be constructed without a
    database connection.
    """
    # TODO(review): mock the signature store and assert that
    # should_process_file(...) returns True for unseen content.
    IncrementalLoader()  # constructor must succeed without any DB state
def test_block_data_extraction(self):
"""Test extraction of changed block data"""
loader = IncrementalLoader(block_size=10)
content = b"Block1____Block2____Block3____"
file_obj = BytesIO(content)
# Create mock change info
from cognee.modules.ingestion.incremental.block_hash_service import BlockInfo, FileSignature
blocks = [
BlockInfo(0, 12345, "hash1", 10, 0),
BlockInfo(1, 23456, "hash2", 10, 10),
BlockInfo(2, 34567, "hash3", 10, 20),
]
signature = FileSignature(
file_path="test",
file_size=30,
@ -144,19 +144,19 @@ class TestIncrementalLoader:
block_size=10,
strong_len=8,
blocks=blocks,
signature_data=b"signature"
signature_data=b"signature",
)
change_info = {
"type": "incremental_changes",
"changed_blocks": [1], # Only middle block changed
"new_signature": signature
"new_signature": signature,
}
# This would normally be called after should_process_file
# Testing the block extraction logic
pass # Placeholder for full integration test
if __name__ == "__main__":
    # Allow running this test module directly (python <file>.py) in addition
    # to discovery via the pytest CLI. The duplicated invocation line was a
    # diff artifact; pytest must be launched exactly once.
    pytest.main([__file__])