reformat applied

This commit is contained in:
vasilije 2025-07-01 14:26:56 +02:00
parent 1ebeeac61d
commit 07f2afa69d
39 changed files with 577 additions and 531 deletions

View file

@ -5,43 +5,49 @@ Revises: 1d0bb7fede17
Create Date: 2025-01-27 12:00:00.000000 Create Date: 2025-01-27 12:00:00.000000
""" """
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from uuid import uuid4 from uuid import uuid4
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'incremental_file_signatures' revision = "incremental_file_signatures"
down_revision = '1d0bb7fede17' down_revision = "1d0bb7fede17"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('file_signatures', op.create_table(
sa.Column('id', sa.UUID(), nullable=False, default=uuid4), "file_signatures",
sa.Column('data_id', sa.UUID(), nullable=True), sa.Column("id", sa.UUID(), nullable=False, default=uuid4),
sa.Column('file_path', sa.String(), nullable=True), sa.Column("data_id", sa.UUID(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True), sa.Column("file_path", sa.String(), nullable=True),
sa.Column('content_hash', sa.String(), nullable=True), sa.Column("file_size", sa.Integer(), nullable=True),
sa.Column('total_blocks', sa.Integer(), nullable=True), sa.Column("content_hash", sa.String(), nullable=True),
sa.Column('block_size', sa.Integer(), nullable=True), sa.Column("total_blocks", sa.Integer(), nullable=True),
sa.Column('strong_len', sa.Integer(), nullable=True), sa.Column("block_size", sa.Integer(), nullable=True),
sa.Column('signature_data', sa.LargeBinary(), nullable=True), sa.Column("strong_len", sa.Integer(), nullable=True),
sa.Column('blocks_info', sa.JSON(), nullable=True), sa.Column("signature_data", sa.LargeBinary(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.Column("blocks_info", sa.JSON(), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_file_signatures_data_id"), "file_signatures", ["data_id"], unique=False
)
op.create_index(
op.f("ix_file_signatures_content_hash"), "file_signatures", ["content_hash"], unique=False
) )
op.create_index(op.f('ix_file_signatures_data_id'), 'file_signatures', ['data_id'], unique=False)
op.create_index(op.f('ix_file_signatures_content_hash'), 'file_signatures', ['content_hash'], unique=False)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_file_signatures_content_hash'), table_name='file_signatures') op.drop_index(op.f("ix_file_signatures_content_hash"), table_name="file_signatures")
op.drop_index(op.f('ix_file_signatures_data_id'), table_name='file_signatures') op.drop_index(op.f("ix_file_signatures_data_id"), table_name="file_signatures")
op.drop_table('file_signatures') op.drop_table("file_signatures")
# ### end Alembic commands ### # ### end Alembic commands ###

View file

@ -556,7 +556,7 @@ def log_database_configuration():
elif relational_config.db_provider == "sqlite": elif relational_config.db_provider == "sqlite":
logger.info(f"SQLite path: {relational_config.db_path}") logger.info(f"SQLite path: {relational_config.db_path}")
logger.info(f"SQLite database: {relational_config.db_name}") logger.info(f"SQLite database: {relational_config.db_name}")
# Log vector database configuration # Log vector database configuration
vector_config = get_vectordb_config() vector_config = get_vectordb_config()
logger.info(f"Vector database: {vector_config.vector_db_provider}") logger.info(f"Vector database: {vector_config.vector_db_provider}")
@ -564,7 +564,7 @@ def log_database_configuration():
logger.info(f"Vector database path: {vector_config.vector_db_url}") logger.info(f"Vector database path: {vector_config.vector_db_url}")
elif vector_config.vector_db_provider in ["qdrant", "weaviate", "pgvector"]: elif vector_config.vector_db_provider in ["qdrant", "weaviate", "pgvector"]:
logger.info(f"Vector database URL: {vector_config.vector_db_url}") logger.info(f"Vector database URL: {vector_config.vector_db_url}")
# Log graph database configuration # Log graph database configuration
graph_config = get_graph_config() graph_config = get_graph_config()
logger.info(f"Graph database: {graph_config.graph_database_provider}") logger.info(f"Graph database: {graph_config.graph_database_provider}")
@ -572,7 +572,7 @@ def log_database_configuration():
logger.info(f"Graph database path: {graph_config.graph_file_path}") logger.info(f"Graph database path: {graph_config.graph_file_path}")
elif graph_config.graph_database_provider in ["neo4j", "falkordb"]: elif graph_config.graph_database_provider in ["neo4j", "falkordb"]:
logger.info(f"Graph database URL: {graph_config.graph_database_url}") logger.info(f"Graph database URL: {graph_config.graph_database_url}")
except Exception as e: except Exception as e:
logger.warning(f"Could not retrieve database configuration: {str(e)}") logger.warning(f"Could not retrieve database configuration: {str(e)}")
@ -591,7 +591,7 @@ async def main():
# Log database configurations # Log database configurations
log_database_configuration() log_database_configuration()
logger.info(f"Starting MCP server with transport: {args.transport}") logger.info(f"Starting MCP server with transport: {args.transport}")
if args.transport == "stdio": if args.transport == "stdio":
await mcp.run_stdio_async() await mcp.run_stdio_async()

View file

@ -18,6 +18,7 @@ litellm.set_verbose = False
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL) logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
logging.getLogger("litellm").setLevel(logging.CRITICAL) logging.getLogger("litellm").setLevel(logging.CRITICAL)
class GenericAPIAdapter(LLMInterface): class GenericAPIAdapter(LLMInterface):
""" """
Adapter for Generic API LLM provider API. Adapter for Generic API LLM provider API.

View file

@ -10,26 +10,28 @@ class FileSignature(Base):
__tablename__ = "file_signatures" __tablename__ = "file_signatures"
id = Column(UUID, primary_key=True, default=uuid4) id = Column(UUID, primary_key=True, default=uuid4)
# Reference to the original data entry # Reference to the original data entry
data_id = Column(UUID, index=True) data_id = Column(UUID, index=True)
# File information # File information
file_path = Column(String) file_path = Column(String)
file_size = Column(Integer) file_size = Column(Integer)
content_hash = Column(String, index=True) # Overall file hash for quick comparison content_hash = Column(String, index=True) # Overall file hash for quick comparison
# Block information # Block information
total_blocks = Column(Integer) total_blocks = Column(Integer)
block_size = Column(Integer) block_size = Column(Integer)
strong_len = Column(Integer) strong_len = Column(Integer)
# Signature data (binary) # Signature data (binary)
signature_data = Column(LargeBinary) signature_data = Column(LargeBinary)
# Block details (JSON array of block info) # Block details (JSON array of block info)
blocks_info = Column(JSON) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset} blocks_info = Column(
JSON
) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
# Timestamps # Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@ -47,4 +49,4 @@ class FileSignature(Base):
"blocks_info": self.blocks_info, "blocks_info": self.blocks_info,
"created_at": self.created_at.isoformat(), "created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat() if self.updated_at else None, "updated_at": self.updated_at.isoformat() if self.updated_at else None,
} }

View file

@ -1,4 +1,4 @@
from .incremental_loader import IncrementalLoader from .incremental_loader import IncrementalLoader
from .block_hash_service import BlockHashService from .block_hash_service import BlockHashService
__all__ = ["IncrementalLoader", "BlockHashService"] __all__ = ["IncrementalLoader", "BlockHashService"]

View file

@ -18,6 +18,7 @@ import tempfile
@dataclass @dataclass
class BlockInfo: class BlockInfo:
"""Information about a file block""" """Information about a file block"""
block_index: int block_index: int
weak_checksum: int weak_checksum: int
strong_hash: str strong_hash: str
@ -28,6 +29,7 @@ class BlockInfo:
@dataclass @dataclass
class FileSignature: class FileSignature:
"""File signature containing block information""" """File signature containing block information"""
file_path: str file_path: str
file_size: int file_size: int
total_blocks: int total_blocks: int
@ -40,6 +42,7 @@ class FileSignature:
@dataclass @dataclass
class FileDelta: class FileDelta:
"""Delta information for changed blocks""" """Delta information for changed blocks"""
changed_blocks: List[int] # Block indices that changed changed_blocks: List[int] # Block indices that changed
delta_data: bytes delta_data: bytes
old_signature: FileSignature old_signature: FileSignature
@ -48,53 +51,51 @@ class FileDelta:
class BlockHashService: class BlockHashService:
"""Service for block-based file hashing using librsync algorithm""" """Service for block-based file hashing using librsync algorithm"""
DEFAULT_BLOCK_SIZE = 1024 # 1KB blocks DEFAULT_BLOCK_SIZE = 1024 # 1KB blocks
DEFAULT_STRONG_LEN = 8 # 8 bytes for strong hash DEFAULT_STRONG_LEN = 8 # 8 bytes for strong hash
def __init__(self, block_size: int = None, strong_len: int = None): def __init__(self, block_size: int = None, strong_len: int = None):
""" """
Initialize the BlockHashService Initialize the BlockHashService
Args: Args:
block_size: Size of blocks in bytes (default: 1024) block_size: Size of blocks in bytes (default: 1024)
strong_len: Length of strong hash in bytes (default: 8) strong_len: Length of strong hash in bytes (default: 8)
""" """
self.block_size = block_size or self.DEFAULT_BLOCK_SIZE self.block_size = block_size or self.DEFAULT_BLOCK_SIZE
self.strong_len = strong_len or self.DEFAULT_STRONG_LEN self.strong_len = strong_len or self.DEFAULT_STRONG_LEN
def generate_signature(self, file_obj: BinaryIO, file_path: str = None) -> FileSignature: def generate_signature(self, file_obj: BinaryIO, file_path: str = None) -> FileSignature:
""" """
Generate a signature for a file using librsync algorithm Generate a signature for a file using librsync algorithm
Args: Args:
file_obj: File object to generate signature for file_obj: File object to generate signature for
file_path: Optional file path for metadata file_path: Optional file path for metadata
Returns: Returns:
FileSignature object containing block information FileSignature object containing block information
""" """
file_obj.seek(0) file_obj.seek(0)
file_data = file_obj.read() file_data = file_obj.read()
file_size = len(file_data) file_size = len(file_data)
# Calculate optimal signature parameters # Calculate optimal signature parameters
magic, block_len, strong_len = get_signature_args( magic, block_len, strong_len = get_signature_args(
file_size, file_size, block_len=self.block_size, strong_len=self.strong_len
block_len=self.block_size,
strong_len=self.strong_len
) )
# Generate signature using librsync # Generate signature using librsync
file_io = BytesIO(file_data) file_io = BytesIO(file_data)
sig_io = BytesIO() sig_io = BytesIO()
signature(file_io, sig_io, strong_len, magic, block_len) signature(file_io, sig_io, strong_len, magic, block_len)
signature_data = sig_io.getvalue() signature_data = sig_io.getvalue()
# Parse signature to extract block information # Parse signature to extract block information
blocks = self._parse_signature(signature_data, file_data, block_len) blocks = self._parse_signature(signature_data, file_data, block_len)
return FileSignature( return FileSignature(
file_path=file_path or "", file_path=file_path or "",
file_size=file_size, file_size=file_size,
@ -102,52 +103,56 @@ class BlockHashService:
block_size=block_len, block_size=block_len,
strong_len=strong_len, strong_len=strong_len,
blocks=blocks, blocks=blocks,
signature_data=signature_data signature_data=signature_data,
) )
def _parse_signature(self, signature_data: bytes, file_data: bytes, block_size: int) -> List[BlockInfo]: def _parse_signature(
self, signature_data: bytes, file_data: bytes, block_size: int
) -> List[BlockInfo]:
""" """
Parse signature data to extract block information Parse signature data to extract block information
Args: Args:
signature_data: Raw signature data from librsync signature_data: Raw signature data from librsync
file_data: Original file data file_data: Original file data
block_size: Size of blocks block_size: Size of blocks
Returns: Returns:
List of BlockInfo objects List of BlockInfo objects
""" """
blocks = [] blocks = []
total_blocks = (len(file_data) + block_size - 1) // block_size total_blocks = (len(file_data) + block_size - 1) // block_size
for i in range(total_blocks): for i in range(total_blocks):
start_offset = i * block_size start_offset = i * block_size
end_offset = min(start_offset + block_size, len(file_data)) end_offset = min(start_offset + block_size, len(file_data))
block_data = file_data[start_offset:end_offset] block_data = file_data[start_offset:end_offset]
# Calculate weak checksum (simple Adler-32 variant) # Calculate weak checksum (simple Adler-32 variant)
weak_checksum = self._calculate_weak_checksum(block_data) weak_checksum = self._calculate_weak_checksum(block_data)
# Calculate strong hash (MD5) # Calculate strong hash (MD5)
strong_hash = hashlib.md5(block_data).hexdigest() strong_hash = hashlib.md5(block_data).hexdigest()
blocks.append(BlockInfo( blocks.append(
block_index=i, BlockInfo(
weak_checksum=weak_checksum, block_index=i,
strong_hash=strong_hash, weak_checksum=weak_checksum,
block_size=len(block_data), strong_hash=strong_hash,
file_offset=start_offset block_size=len(block_data),
)) file_offset=start_offset,
)
)
return blocks return blocks
def _calculate_weak_checksum(self, data: bytes) -> int: def _calculate_weak_checksum(self, data: bytes) -> int:
""" """
Calculate a weak checksum similar to Adler-32 Calculate a weak checksum similar to Adler-32
Args: Args:
data: Block data data: Block data
Returns: Returns:
Weak checksum value Weak checksum value
""" """
@ -157,111 +162,116 @@ class BlockHashService:
a = (a + byte) % 65521 a = (a + byte) % 65521
b = (b + a) % 65521 b = (b + a) % 65521
return (b << 16) | a return (b << 16) | a
def compare_signatures(self, old_sig: FileSignature, new_sig: FileSignature) -> List[int]: def compare_signatures(self, old_sig: FileSignature, new_sig: FileSignature) -> List[int]:
""" """
Compare two signatures to find changed blocks Compare two signatures to find changed blocks
Args: Args:
old_sig: Previous file signature old_sig: Previous file signature
new_sig: New file signature new_sig: New file signature
Returns: Returns:
List of block indices that have changed List of block indices that have changed
""" """
changed_blocks = [] changed_blocks = []
# Create lookup tables for efficient comparison # Create lookup tables for efficient comparison
old_blocks = {block.block_index: block for block in old_sig.blocks} old_blocks = {block.block_index: block for block in old_sig.blocks}
new_blocks = {block.block_index: block for block in new_sig.blocks} new_blocks = {block.block_index: block for block in new_sig.blocks}
# Find changed, added, or removed blocks # Find changed, added, or removed blocks
all_indices = set(old_blocks.keys()) | set(new_blocks.keys()) all_indices = set(old_blocks.keys()) | set(new_blocks.keys())
for block_idx in all_indices: for block_idx in all_indices:
old_block = old_blocks.get(block_idx) old_block = old_blocks.get(block_idx)
new_block = new_blocks.get(block_idx) new_block = new_blocks.get(block_idx)
if old_block is None or new_block is None: if old_block is None or new_block is None:
# Block was added or removed # Block was added or removed
changed_blocks.append(block_idx) changed_blocks.append(block_idx)
elif (old_block.weak_checksum != new_block.weak_checksum or elif (
old_block.strong_hash != new_block.strong_hash): old_block.weak_checksum != new_block.weak_checksum
or old_block.strong_hash != new_block.strong_hash
):
# Block content changed # Block content changed
changed_blocks.append(block_idx) changed_blocks.append(block_idx)
return sorted(changed_blocks) return sorted(changed_blocks)
def generate_delta(self, old_file: BinaryIO, new_file: BinaryIO, def generate_delta(
old_signature: FileSignature = None) -> FileDelta: self, old_file: BinaryIO, new_file: BinaryIO, old_signature: FileSignature = None
) -> FileDelta:
""" """
Generate a delta between two file versions Generate a delta between two file versions
Args: Args:
old_file: Previous version of the file old_file: Previous version of the file
new_file: New version of the file new_file: New version of the file
old_signature: Optional pre-computed signature of old file old_signature: Optional pre-computed signature of old file
Returns: Returns:
FileDelta object containing change information FileDelta object containing change information
""" """
# Generate signatures if not provided # Generate signatures if not provided
if old_signature is None: if old_signature is None:
old_signature = self.generate_signature(old_file) old_signature = self.generate_signature(old_file)
new_signature = self.generate_signature(new_file) new_signature = self.generate_signature(new_file)
# Generate delta using librsync # Generate delta using librsync
new_file.seek(0) new_file.seek(0)
old_sig_io = BytesIO(old_signature.signature_data) old_sig_io = BytesIO(old_signature.signature_data)
delta_io = BytesIO() delta_io = BytesIO()
delta(new_file, old_sig_io, delta_io) delta(new_file, old_sig_io, delta_io)
delta_data = delta_io.getvalue() delta_data = delta_io.getvalue()
# Find changed blocks # Find changed blocks
changed_blocks = self.compare_signatures(old_signature, new_signature) changed_blocks = self.compare_signatures(old_signature, new_signature)
return FileDelta( return FileDelta(
changed_blocks=changed_blocks, changed_blocks=changed_blocks,
delta_data=delta_data, delta_data=delta_data,
old_signature=old_signature, old_signature=old_signature,
new_signature=new_signature new_signature=new_signature,
) )
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO: def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
""" """
Apply a delta to reconstruct the new file Apply a delta to reconstruct the new file
Args: Args:
old_file: Original file old_file: Original file
delta_obj: Delta information delta_obj: Delta information
Returns: Returns:
BytesIO object containing the reconstructed file BytesIO object containing the reconstructed file
""" """
old_file.seek(0) old_file.seek(0)
delta_io = BytesIO(delta_obj.delta_data) delta_io = BytesIO(delta_obj.delta_data)
result_io = BytesIO() result_io = BytesIO()
patch(old_file, delta_io, result_io) patch(old_file, delta_io, result_io)
result_io.seek(0) result_io.seek(0)
return result_io return result_io
def calculate_block_changes(self, old_sig: FileSignature, new_sig: FileSignature) -> Dict[str, Any]: def calculate_block_changes(
self, old_sig: FileSignature, new_sig: FileSignature
) -> Dict[str, Any]:
""" """
Calculate detailed statistics about block changes Calculate detailed statistics about block changes
Args: Args:
old_sig: Previous file signature old_sig: Previous file signature
new_sig: New file signature new_sig: New file signature
Returns: Returns:
Dictionary with change statistics Dictionary with change statistics
""" """
changed_blocks = self.compare_signatures(old_sig, new_sig) changed_blocks = self.compare_signatures(old_sig, new_sig)
return { return {
"total_old_blocks": len(old_sig.blocks), "total_old_blocks": len(old_sig.blocks),
"total_new_blocks": len(new_sig.blocks), "total_new_blocks": len(new_sig.blocks),
@ -271,4 +281,4 @@ class BlockHashService:
"compression_ratio": 1.0 - (len(changed_blocks) / max(len(old_sig.blocks), 1)), "compression_ratio": 1.0 - (len(changed_blocks) / max(len(old_sig.blocks), 1)),
"old_file_size": old_sig.file_size, "old_file_size": old_sig.file_size,
"new_file_size": new_sig.file_size, "new_file_size": new_sig.file_size,
} }

View file

@ -26,99 +26,107 @@ class IncrementalLoader:
""" """
Incremental file loader using rsync algorithm for efficient updates Incremental file loader using rsync algorithm for efficient updates
""" """
def __init__(self, block_size: int = 1024, strong_len: int = 8): def __init__(self, block_size: int = 1024, strong_len: int = 8):
""" """
Initialize the incremental loader Initialize the incremental loader
Args: Args:
block_size: Size of blocks in bytes for rsync algorithm block_size: Size of blocks in bytes for rsync algorithm
strong_len: Length of strong hash in bytes strong_len: Length of strong hash in bytes
""" """
self.block_service = BlockHashService(block_size, strong_len) self.block_service = BlockHashService(block_size, strong_len)
async def should_process_file(self, file_obj: BinaryIO, data_id: str) -> Tuple[bool, Optional[Dict]]: async def should_process_file(
self, file_obj: BinaryIO, data_id: str
) -> Tuple[bool, Optional[Dict]]:
""" """
Determine if a file should be processed based on incremental changes Determine if a file should be processed based on incremental changes
Args: Args:
file_obj: File object to check file_obj: File object to check
data_id: Data ID for the file data_id: Data ID for the file
Returns: Returns:
Tuple of (should_process, change_info) Tuple of (should_process, change_info)
- should_process: True if file needs processing - should_process: True if file needs processing
- change_info: Dictionary with change details if applicable - change_info: Dictionary with change details if applicable
""" """
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# Check if we have an existing signature for this file # Check if we have an existing signature for this file
existing_signature = await self._get_existing_signature(session, data_id) existing_signature = await self._get_existing_signature(session, data_id)
if existing_signature is None: if existing_signature is None:
# First time seeing this file, needs full processing # First time seeing this file, needs full processing
return True, {"type": "new_file", "full_processing": True} return True, {"type": "new_file", "full_processing": True}
# Generate signature for current file version # Generate signature for current file version
current_signature = self.block_service.generate_signature(file_obj) current_signature = self.block_service.generate_signature(file_obj)
# Quick check: if overall content hash is the same, no changes # Quick check: if overall content hash is the same, no changes
file_obj.seek(0) file_obj.seek(0)
current_content_hash = get_file_content_hash(file_obj) current_content_hash = get_file_content_hash(file_obj)
if current_content_hash == existing_signature.content_hash: if current_content_hash == existing_signature.content_hash:
return False, {"type": "no_changes", "full_processing": False} return False, {"type": "no_changes", "full_processing": False}
# Convert database signature to service signature for comparison # Convert database signature to service signature for comparison
service_old_sig = self._db_signature_to_service(existing_signature) service_old_sig = self._db_signature_to_service(existing_signature)
# Compare signatures to find changed blocks # Compare signatures to find changed blocks
changed_blocks = self.block_service.compare_signatures(service_old_sig, current_signature) changed_blocks = self.block_service.compare_signatures(
service_old_sig, current_signature
)
if not changed_blocks: if not changed_blocks:
# Signatures match, no processing needed # Signatures match, no processing needed
return False, {"type": "no_changes", "full_processing": False} return False, {"type": "no_changes", "full_processing": False}
# Calculate change statistics # Calculate change statistics
change_stats = self.block_service.calculate_block_changes(service_old_sig, current_signature) change_stats = self.block_service.calculate_block_changes(
service_old_sig, current_signature
)
change_info = { change_info = {
"type": "incremental_changes", "type": "incremental_changes",
"full_processing": len(changed_blocks) > (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess "full_processing": len(changed_blocks)
> (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
"changed_blocks": changed_blocks, "changed_blocks": changed_blocks,
"stats": change_stats, "stats": change_stats,
"new_signature": current_signature, "new_signature": current_signature,
"old_signature": service_old_sig, "old_signature": service_old_sig,
} }
return True, change_info return True, change_info
async def process_incremental_changes(self, file_obj: BinaryIO, data_id: str, async def process_incremental_changes(
change_info: Dict) -> List[Dict]: self, file_obj: BinaryIO, data_id: str, change_info: Dict
) -> List[Dict]:
""" """
Process only the changed blocks of a file Process only the changed blocks of a file
Args: Args:
file_obj: File object to process file_obj: File object to process
data_id: Data ID for the file data_id: Data ID for the file
change_info: Change information from should_process_file change_info: Change information from should_process_file
Returns: Returns:
List of block data that needs reprocessing List of block data that needs reprocessing
""" """
if change_info["type"] != "incremental_changes": if change_info["type"] != "incremental_changes":
raise ValueError("Invalid change_info type for incremental processing") raise ValueError("Invalid change_info type for incremental processing")
file_obj.seek(0) file_obj.seek(0)
file_data = file_obj.read() file_data = file_obj.read()
changed_blocks = change_info["changed_blocks"] changed_blocks = change_info["changed_blocks"]
new_signature = change_info["new_signature"] new_signature = change_info["new_signature"]
# Extract data for changed blocks # Extract data for changed blocks
changed_block_data = [] changed_block_data = []
for block_idx in changed_blocks: for block_idx in changed_blocks:
# Find the block info # Find the block info
block_info = None block_info = None
@ -126,49 +134,51 @@ class IncrementalLoader:
if block.block_index == block_idx: if block.block_index == block_idx:
block_info = block block_info = block
break break
if block_info is None: if block_info is None:
continue continue
# Extract block data # Extract block data
start_offset = block_info.file_offset start_offset = block_info.file_offset
end_offset = start_offset + block_info.block_size end_offset = start_offset + block_info.block_size
block_data = file_data[start_offset:end_offset] block_data = file_data[start_offset:end_offset]
changed_block_data.append({ changed_block_data.append(
"block_index": block_idx, {
"block_data": block_data, "block_index": block_idx,
"block_info": block_info, "block_data": block_data,
"file_offset": start_offset, "block_info": block_info,
"block_size": len(block_data), "file_offset": start_offset,
}) "block_size": len(block_data),
}
)
return changed_block_data return changed_block_data
async def save_file_signature(self, file_obj: BinaryIO, data_id: str) -> None: async def save_file_signature(self, file_obj: BinaryIO, data_id: str) -> None:
""" """
Save or update the file signature in the database Save or update the file signature in the database
Args: Args:
file_obj: File object file_obj: File object
data_id: Data ID for the file data_id: Data ID for the file
""" """
# Generate signature # Generate signature
signature = self.block_service.generate_signature(file_obj, str(data_id)) signature = self.block_service.generate_signature(file_obj, str(data_id))
# Calculate content hash # Calculate content hash
file_obj.seek(0) file_obj.seek(0)
content_hash = get_file_content_hash(file_obj) content_hash = get_file_content_hash(file_obj)
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# Check if signature already exists # Check if signature already exists
existing = await session.execute( existing = await session.execute(
select(FileSignature).filter(FileSignature.data_id == data_id) select(FileSignature).filter(FileSignature.data_id == data_id)
) )
existing_signature = existing.scalar_one_or_none() existing_signature = existing.scalar_one_or_none()
# Prepare block info for JSON storage # Prepare block info for JSON storage
blocks_info = [ blocks_info = [
{ {
@ -180,7 +190,7 @@ class IncrementalLoader:
} }
for block in signature.blocks for block in signature.blocks
] ]
if existing_signature: if existing_signature:
# Update existing signature # Update existing signature
existing_signature.file_path = signature.file_path existing_signature.file_path = signature.file_path
@ -191,7 +201,7 @@ class IncrementalLoader:
existing_signature.strong_len = signature.strong_len existing_signature.strong_len = signature.strong_len
existing_signature.signature_data = signature.signature_data existing_signature.signature_data = signature.signature_data
existing_signature.blocks_info = blocks_info existing_signature.blocks_info = blocks_info
await session.merge(existing_signature) await session.merge(existing_signature)
else: else:
# Create new signature # Create new signature
@ -207,17 +217,19 @@ class IncrementalLoader:
blocks_info=blocks_info, blocks_info=blocks_info,
) )
session.add(new_signature) session.add(new_signature)
await session.commit() await session.commit()
async def _get_existing_signature(self, session: AsyncSession, data_id: str) -> Optional[FileSignature]: async def _get_existing_signature(
self, session: AsyncSession, data_id: str
) -> Optional[FileSignature]:
""" """
Get existing file signature from database Get existing file signature from database
Args: Args:
session: Database session session: Database session
data_id: Data ID to search for data_id: Data ID to search for
Returns: Returns:
FileSignature object or None if not found FileSignature object or None if not found
""" """
@ -225,19 +237,19 @@ class IncrementalLoader:
select(FileSignature).filter(FileSignature.data_id == data_id) select(FileSignature).filter(FileSignature.data_id == data_id)
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
def _db_signature_to_service(self, db_signature: FileSignature) -> ServiceFileSignature: def _db_signature_to_service(self, db_signature: FileSignature) -> ServiceFileSignature:
""" """
Convert database FileSignature to service FileSignature Convert database FileSignature to service FileSignature
Args: Args:
db_signature: Database signature object db_signature: Database signature object
Returns: Returns:
Service FileSignature object Service FileSignature object
""" """
from .block_hash_service import BlockInfo from .block_hash_service import BlockInfo
# Convert blocks info # Convert blocks info
blocks = [ blocks = [
BlockInfo( BlockInfo(
@ -249,7 +261,7 @@ class IncrementalLoader:
) )
for block in db_signature.blocks_info for block in db_signature.blocks_info
] ]
return ServiceFileSignature( return ServiceFileSignature(
file_path=db_signature.file_path, file_path=db_signature.file_path,
file_size=db_signature.file_size, file_size=db_signature.file_size,
@ -259,26 +271,26 @@ class IncrementalLoader:
blocks=blocks, blocks=blocks,
signature_data=db_signature.signature_data, signature_data=db_signature.signature_data,
) )
async def cleanup_orphaned_signatures(self) -> int: async def cleanup_orphaned_signatures(self) -> int:
""" """
Clean up file signatures that no longer have corresponding data entries Clean up file signatures that no longer have corresponding data entries
Returns: Returns:
Number of signatures removed Number of signatures removed
""" """
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# Find signatures without corresponding data entries # Find signatures without corresponding data entries
orphaned_query = """ orphaned_query = """
DELETE FROM file_signatures DELETE FROM file_signatures
WHERE data_id NOT IN (SELECT id FROM data) WHERE data_id NOT IN (SELECT id FROM data)
""" """
result = await session.execute(orphaned_query) result = await session.execute(orphaned_query)
removed_count = result.rowcount removed_count = result.rowcount
await session.commit() await session.commit()
return removed_count return removed_count

View file

@ -109,13 +109,15 @@ async def ingest_data(
data_id = ingestion.identify(classified_data, user) data_id = ingestion.identify(classified_data, user)
file_metadata = classified_data.get_metadata() file_metadata = classified_data.get_metadata()
# Initialize incremental loader for this file # Initialize incremental loader for this file
incremental_loader = IncrementalLoader() incremental_loader = IncrementalLoader()
# Check if file needs incremental processing # Check if file needs incremental processing
should_process, change_info = await incremental_loader.should_process_file(file, data_id) should_process, change_info = await incremental_loader.should_process_file(
file, data_id
)
# Save updated file signature regardless of whether processing is needed # Save updated file signature regardless of whether processing is needed
await incremental_loader.save_file_signature(file, data_id) await incremental_loader.save_file_signature(file, data_id)
@ -150,12 +152,12 @@ async def ingest_data(
ext_metadata = get_external_metadata_dict(data_item) ext_metadata = get_external_metadata_dict(data_item)
if node_set: if node_set:
ext_metadata["node_set"] = node_set ext_metadata["node_set"] = node_set
# Add incremental processing metadata # Add incremental processing metadata
ext_metadata["incremental_processing"] = { ext_metadata["incremental_processing"] = {
"should_process": should_process, "should_process": should_process,
"change_info": change_info, "change_info": change_info,
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())) "processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())),
} }
if data_point is not None: if data_point is not None:

View file

@ -51,12 +51,12 @@ def test_AudioDocument(mock_engine):
GROUND_TRUTH, GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64), document.read(chunker_cls=TextChunker, max_chunk_size=64),
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -34,12 +34,12 @@ def test_ImageDocument(mock_engine):
GROUND_TRUTH, GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64), document.read(chunker_cls=TextChunker, max_chunk_size=64),
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -36,12 +36,12 @@ def test_PdfDocument(mock_engine):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024) GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -49,12 +49,12 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
GROUND_TRUTH[input_file], GROUND_TRUTH[input_file],
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size), document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
): ):
assert ground_truth["word_count"] == paragraph_data.chunk_size, ( assert (
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ground_truth["word_count"] == paragraph_data.chunk_size
) ), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
assert ground_truth["len_text"] == len(paragraph_data.text), ( assert ground_truth["len_text"] == len(
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' paragraph_data.text
) ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert ground_truth["cut_type"] == paragraph_data.cut_type, ( assert (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ground_truth["cut_type"] == paragraph_data.cut_type
) ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -79,32 +79,32 @@ def test_UnstructuredDocument(mock_engine):
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }" assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, ( assert (
f" sentence_cut != {paragraph_data.cut_type = }" "sentence_cut" == paragraph_data.cut_type
) ), f" sentence_cut != {paragraph_data.cut_type = }"
# Test DOCX # Test DOCX
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }" assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert "sentence_end" == paragraph_data.cut_type, ( assert (
f" sentence_end != {paragraph_data.cut_type = }" "sentence_end" == paragraph_data.cut_type
) ), f" sentence_end != {paragraph_data.cut_type = }"
# TEST CSV # TEST CSV
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }" assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, ( assert (
f"Read text doesn't match expected text: {paragraph_data.text}" "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
) ), f"Read text doesn't match expected text: {paragraph_data.text}"
assert "sentence_cut" == paragraph_data.cut_type, ( assert (
f" sentence_cut != {paragraph_data.cut_type = }" "sentence_cut" == paragraph_data.cut_type
) ), f" sentence_cut != {paragraph_data.cut_type = }"
# Test XLSX # Test XLSX
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024): for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }" assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, ( assert (
f" sentence_cut != {paragraph_data.cut_type = }" "sentence_cut" == paragraph_data.cut_type
) ), f" sentence_cut != {paragraph_data.cut_type = }"

View file

@ -12,9 +12,9 @@ async def check_graph_metrics_consistency_across_adapters(include_optional=False
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}") raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
for key, neo4j_value in neo4j_metrics.items(): for key, neo4j_value in neo4j_metrics.items():
assert networkx_metrics[key] == neo4j_value, ( assert (
f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx" networkx_metrics[key] == neo4j_value
) ), f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -71,6 +71,6 @@ async def assert_metrics(provider, include_optional=True):
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}") raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
for key, ground_truth_value in ground_truth_metrics.items(): for key, ground_truth_value in ground_truth_metrics.items():
assert metrics[key] == ground_truth_value, ( assert (
f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}" metrics[key] == ground_truth_value
) ), f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"

View file

@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest() data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents # Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one() data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee # Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), ( assert not os.path.exists(
f"Data location still exists after deletion: {data.raw_data_location}" data.raw_data_location
) ), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session: async with engine.get_async_session() as session:
# Get data entry from database based on file path # Get data entry from database based on file path
data = ( data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location)) await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one() ).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted # Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), ( assert os.path.exists(
f"Data location doesn't exists: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1): async def test_getting_of_documents(dataset_name_1):
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 1" len(document_ids) == 1
) ), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided # Test getting of documents for search when no dataset is provided
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id) document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 2" len(document_ids) == 2
) ), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main(): async def main():

View file

@ -30,9 +30,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data") result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found." assert len(result) == 1, "More than one data entity was found."
assert result[0]["name"] == "Natural_language_processing_copy", ( assert (
"Result name does not match expected value." result[0]["name"] == "Natural_language_processing_copy"
) ), "Result name does not match expected value."
result = await relational_engine.get_all_data_from_table("datasets") result = await relational_engine.get_all_data_from_table("datasets")
assert len(result) == 2, "Unexpected number of datasets found." assert len(result) == 2, "Unexpected number of datasets found."
@ -61,9 +61,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data") result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found." assert len(result) == 1, "More than one data entity was found."
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], ( assert (
"Content hash is not a part of file name." hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
) ), "Content hash is not a part of file name."
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)

View file

@ -92,9 +92,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(get_relational_engine().db_path), ( assert not os.path.exists(
"SQLite relational database is not empty" get_relational_engine().db_path
) ), "SQLite relational database is not empty"
from cognee.infrastructure.databases.graph import get_graph_config from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -103,13 +103,13 @@ async def main():
node_name=["nonexistent"], node_name=["nonexistent"],
).get_context("What is in the context?") ).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", ( assert (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" isinstance(context_nonempty, str) and context_nonempty != ""
) ), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
assert context_empty == "", ( assert (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" context_empty == ""
) ), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -107,13 +107,13 @@ async def main():
node_name=["nonexistent"], node_name=["nonexistent"],
).get_context("What is in the context?") ).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", ( assert (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" isinstance(context_nonempty, str) and context_nonempty != ""
) ), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
assert context_empty == "", ( assert (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" context_empty == ""
) ), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -23,28 +23,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest() data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents # Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one() data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee # Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), ( assert not os.path.exists(
f"Data location still exists after deletion: {data.raw_data_location}" data.raw_data_location
) ), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session: async with engine.get_async_session() as session:
# Get data entry from database based on file path # Get data entry from database based on file path
data = ( data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location)) await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one() ).one()
assert os.path.isfile(data.raw_data_location), ( assert os.path.isfile(
f"Data location doesn't exist: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted # Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), ( assert os.path.exists(
f"Data location doesn't exists: {data.raw_data_location}" data.raw_data_location
) ), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1): async def test_getting_of_documents(dataset_name_1):
@ -53,16 +53,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 1" len(document_ids) == 1
) ), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided # Test getting of documents for search when no dataset is provided
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id) document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, ( assert (
f"Number of expected documents doesn't match {len(document_ids)} != 2" len(document_ids) == 2
) ), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main(): async def main():

View file

@ -112,9 +112,9 @@ async def relational_db_migration():
else: else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
assert len(distinct_node_names) == 12, ( assert (
f"Expected 12 distinct node references, found {len(distinct_node_names)}" len(distinct_node_names) == 12
) ), f"Expected 12 distinct node references, found {len(distinct_node_names)}"
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}" assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
expected_edges = { expected_edges = {

View file

@ -29,54 +29,54 @@ async def main():
logging.info(edge_type_counts) logging.info(edge_type_counts)
# Assert there is exactly one PdfDocument. # Assert there is exactly one PdfDocument.
assert type_counts.get("PdfDocument", 0) == 1, ( assert (
f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}" type_counts.get("PdfDocument", 0) == 1
) ), f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
# Assert there is exactly one TextDocument. # Assert there is exactly one TextDocument.
assert type_counts.get("TextDocument", 0) == 1, ( assert (
f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}" type_counts.get("TextDocument", 0) == 1
) ), f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
# Assert there are at least two DocumentChunk nodes. # Assert there are at least two DocumentChunk nodes.
assert type_counts.get("DocumentChunk", 0) >= 2, ( assert (
f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}" type_counts.get("DocumentChunk", 0) >= 2
) ), f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
# Assert there is at least two TextSummary. # Assert there is at least two TextSummary.
assert type_counts.get("TextSummary", 0) >= 2, ( assert (
f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}" type_counts.get("TextSummary", 0) >= 2
) ), f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
# Assert there is at least one Entity. # Assert there is at least one Entity.
assert type_counts.get("Entity", 0) > 0, ( assert (
f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}" type_counts.get("Entity", 0) > 0
) ), f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
# Assert there is at least one EntityType. # Assert there is at least one EntityType.
assert type_counts.get("EntityType", 0) > 0, ( assert (
f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}" type_counts.get("EntityType", 0) > 0
) ), f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
# Assert that there are at least two 'is_part_of' edges. # Assert that there are at least two 'is_part_of' edges.
assert edge_type_counts.get("is_part_of", 0) >= 2, ( assert (
f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}" edge_type_counts.get("is_part_of", 0) >= 2
) ), f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
# Assert that there are at least two 'made_from' edges. # Assert that there are at least two 'made_from' edges.
assert edge_type_counts.get("made_from", 0) >= 2, ( assert (
f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}" edge_type_counts.get("made_from", 0) >= 2
) ), f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
# Assert that there is at least one 'is_a' edge. # Assert that there is at least one 'is_a' edge.
assert edge_type_counts.get("is_a", 0) >= 1, ( assert (
f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}" edge_type_counts.get("is_a", 0) >= 1
) ), f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
# Assert that there is at least one 'contains' edge. # Assert that there is at least one 'contains' edge.
assert edge_type_counts.get("contains", 0) >= 1, ( assert (
f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}" edge_type_counts.get("contains", 0) >= 1
) ), f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -66,9 +66,9 @@ async def main():
assert isinstance(context, str), f"{name}: Context should be a string" assert isinstance(context, str), f"{name}: Context should be a string"
assert context.strip(), f"{name}: Context should not be empty" assert context.strip(), f"{name}: Context should not be empty"
lower = context.lower() lower = context.lower()
assert "germany" in lower or "netherlands" in lower, ( assert (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}" "germany" in lower or "netherlands" in lower
) ), f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
triplets_gk = await GraphCompletionRetriever().get_triplets( triplets_gk = await GraphCompletionRetriever().get_triplets(
query="Next to which country is Germany located?" query="Next to which country is Germany located?"
@ -96,18 +96,18 @@ async def main():
distance = edge.attributes.get("vector_distance") distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance") node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance") node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), ( assert isinstance(
f"{name}: vector_distance should be float, got {type(distance)}" distance, float
) ), f"{name}: vector_distance should be float, got {type(distance)}"
assert 0 <= distance <= 1, ( assert (
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen" 0 <= distance <= 1
) ), f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
assert 0 <= node1_distance <= 1, ( assert (
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen" 0 <= node1_distance <= 1
) ), f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
assert 0 <= node2_distance <= 1, ( assert (
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen" 0 <= node2_distance <= 1
) ), f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
completion_gk = await cognee.search( completion_gk = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_type=SearchType.GRAPH_COMPLETION,
@ -137,9 +137,9 @@ async def main():
text = completion[0] text = completion[0]
assert isinstance(text, str), f"{name}: element should be a string" assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty" assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), ( assert (
f"{name}: expected 'netherlands' in result, got: {text!r}" "netherlands" in text.lower()
) ), f"{name}: expected 'netherlands' in result, got: {text!r}"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -24,12 +24,12 @@ async def test_answer_generation():
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"]) mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
assert len(answers) == len(qa_pairs) assert len(answers) == len(qa_pairs)
assert answers[0]["question"] == qa_pairs[0]["question"], ( assert (
"AnswerGeneratorExecutor is passing the question incorrectly" answers[0]["question"] == qa_pairs[0]["question"]
) ), "AnswerGeneratorExecutor is passing the question incorrectly"
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], ( assert (
"AnswerGeneratorExecutor is passing the golden answer incorrectly" answers[0]["golden_answer"] == qa_pairs[0]["answer"]
) ), "AnswerGeneratorExecutor is passing the golden answer incorrectly"
assert answers[0]["answer"] == "Mocked answer", ( assert (
"AnswerGeneratorExecutor is passing the generated answer incorrectly" answers[0]["answer"] == "Mocked answer"
) ), "AnswerGeneratorExecutor is passing the generated answer incorrectly"

View file

@ -44,9 +44,9 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
corpus_list, qa_pairs = result corpus_list, qa_pairs = result
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list." assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
assert isinstance(qa_pairs, list), ( assert isinstance(
f"{AdapterClass.__name__} question_answer_pairs is not a list." qa_pairs, list
) ), f"{AdapterClass.__name__} question_answer_pairs is not a list."
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES) @pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
@ -71,9 +71,9 @@ def test_adapter_returns_some_content(AdapterClass):
# We don't know how large the dataset is, but we expect at least 1 item # We don't know how large the dataset is, but we expect at least 1 item
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list." assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs." assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
assert len(qa_pairs) <= limit, ( assert (
f"{AdapterClass.__name__} returned more QA items than requested limit={limit}." len(qa_pairs) <= limit
) ), f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
for item in qa_pairs: for item in qa_pairs:
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair." assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."

View file

@ -12,9 +12,9 @@ def test_corpus_builder_load_corpus(benchmark):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default") corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit) raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}" assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, ( assert (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" len(questions) <= 2
) ), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -24,6 +24,6 @@ async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
limit = 2 limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default") corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit) questions = await corpus_builder.build_corpus(limit=limit)
assert len(questions) <= 2, ( assert (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" len(questions) <= 2
) ), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"

View file

@ -52,14 +52,14 @@ def test_metrics(metrics, actual, expected, expected_exact_score, expected_f1_ra
test_case = MockTestCase(actual, expected) test_case = MockTestCase(actual, expected)
exact_match_score = metrics["exact_match"].measure(test_case) exact_match_score = metrics["exact_match"].measure(test_case)
assert exact_match_score == expected_exact_score, ( assert (
f"Exact match failed for '{actual}' vs '{expected}'" exact_match_score == expected_exact_score
) ), f"Exact match failed for '{actual}' vs '{expected}'"
f1_score = metrics["f1"].measure(test_case) f1_score = metrics["f1"].measure(test_case)
assert expected_f1_range[0] <= f1_score <= expected_f1_range[1], ( assert (
f"F1 score failed for '{actual}' vs '{expected}'" expected_f1_range[0] <= f1_score <= expected_f1_range[1]
) ), f"F1 score failed for '{actual}' vs '{expected}'"
class TestBootstrapCI(unittest.TestCase): class TestBootstrapCI(unittest.TestCase):

View file

@ -157,15 +157,15 @@ def test_rate_limit_60_per_minute():
if len(failures) > 0: if len(failures) > 0:
first_failure_idx = int(failures[0].split()[1]) first_failure_idx = int(failures[0].split()[1])
print(f"First failure occurred at request index: {first_failure_idx}") print(f"First failure occurred at request index: {first_failure_idx}")
assert 58 <= first_failure_idx <= 62, ( assert (
f"Expected first failure around request #60, got #{first_failure_idx}" 58 <= first_failure_idx <= 62
) ), f"Expected first failure around request #60, got #{first_failure_idx}"
# Calculate requests per minute # Calculate requests per minute
rate_per_minute = len(successes) rate_per_minute = len(successes)
print(f"Rate: {rate_per_minute} requests per minute") print(f"Rate: {rate_per_minute} requests per minute")
# Verify the rate is close to our target of 60 requests per minute # Verify the rate is close to our target of 60 requests per minute
assert 58 <= rate_per_minute <= 62, ( assert (
f"Expected rate of ~60 requests per minute, got {rate_per_minute}" 58 <= rate_per_minute <= 62
) ), f"Expected rate of ~60 requests per minute, got {rate_per_minute}"

View file

@ -110,9 +110,9 @@ def test_sync_retry():
print(f"Number of attempts: {test_function_sync.counter}") print(f"Number of attempts: {test_function_sync.counter}")
# The function should succeed on the 3rd attempt (after 2 failures) # The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_sync.counter == 3, ( assert (
f"Expected 3 attempts, got {test_function_sync.counter}" test_function_sync.counter == 3
) ), f"Expected 3 attempts, got {test_function_sync.counter}"
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}" assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Synchronous retry mechanism is working correctly") print("✅ PASS: Synchronous retry mechanism is working correctly")
@ -143,9 +143,9 @@ async def test_async_retry():
print(f"Number of attempts: {test_function_async.counter}") print(f"Number of attempts: {test_function_async.counter}")
# The function should succeed on the 3rd attempt (after 2 failures) # The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_async.counter == 3, ( assert (
f"Expected 3 attempts, got {test_function_async.counter}" test_function_async.counter == 3
) ), f"Expected 3 attempts, got {test_function_async.counter}"
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}" assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Asynchronous retry mechanism is working correctly") print("✅ PASS: Asynchronous retry mechanism is working correctly")

View file

@ -57,9 +57,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Canva?") answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self): async def test_graph_completion_extension_context_complex(self):
@ -136,9 +136,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self): async def test_get_graph_completion_extension_context_on_empty_graph(self):
@ -167,9 +167,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -55,9 +55,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Canva?") answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self): async def test_graph_completion_cot_context_complex(self):
@ -134,9 +134,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self): async def test_get_graph_completion_cot_context_on_empty_graph(self):
@ -165,9 +165,9 @@ class TestGraphCompletionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), ( assert all(
"Answer must contain only non-empty strings" isinstance(item, str) and item.strip() for item in answer
) ), "Answer must contain only non-empty strings"
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -24,9 +24,9 @@ max_chunk_size_vals = [512, 1024, 4096]
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs): def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs) chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks]) reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert reconstructed_text == input_text, ( assert (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" reconstructed_text == input_text
) ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -54,9 +54,9 @@ def test_paragraph_chunk_length(input_text, max_chunk_size, batch_paragraphs):
) )
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
assert np.all(chunk_lengths <= max_chunk_size), ( assert np.all(
f"{max_chunk_size = }: {larger_chunks} are too large" chunk_lengths <= max_chunk_size
) ), f"{max_chunk_size = }: {larger_chunks} are too large"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -76,6 +76,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, max_chunk_size, batch_pa
batch_paragraphs=batch_paragraphs, batch_paragraphs=batch_paragraphs,
) )
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( assert np.all(
f"{chunk_indices = } are not monotonically increasing" chunk_indices == np.arange(len(chunk_indices))
) ), f"{chunk_indices = } are not monotonically increasing"

View file

@ -71,9 +71,9 @@ def run_chunking_test(test_text, expected_chunks, mock_engine):
for expected_chunks_item, chunk in zip(expected_chunks, chunks): for expected_chunks_item, chunk in zip(expected_chunks, chunks):
for key in ["text", "chunk_size", "cut_type"]: for key in ["text", "chunk_size", "cut_type"]:
assert chunk[key] == expected_chunks_item[key], ( assert (
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }" chunk[key] == expected_chunks_item[key]
) ), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
def test_chunking_whole_text(): def test_chunking_whole_text():

View file

@ -17,9 +17,9 @@ maximum_length_vals = [None, 16, 64]
def test_chunk_by_sentence_isomorphism(input_text, maximum_length): def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
chunks = chunk_by_sentence(input_text, maximum_length) chunks = chunk_by_sentence(input_text, maximum_length)
reconstructed_text = "".join([chunk[1] for chunk in chunks]) reconstructed_text = "".join([chunk[1] for chunk in chunks])
assert reconstructed_text == input_text, ( assert (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" reconstructed_text == input_text
) ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -40,9 +40,9 @@ def test_paragraph_chunk_length(input_text, maximum_length):
) )
larger_chunks = chunk_lengths[chunk_lengths > maximum_length] larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
assert np.all(chunk_lengths <= maximum_length), ( assert np.all(
f"{maximum_length = }: {larger_chunks} are too large" chunk_lengths <= maximum_length
) ), f"{maximum_length = }: {larger_chunks} are too large"
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS, INPUT_TE
def test_chunk_by_word_isomorphism(input_text): def test_chunk_by_word_isomorphism(input_text):
chunks = chunk_by_word(input_text) chunks = chunk_by_word(input_text)
reconstructed_text = "".join([chunk[0] for chunk in chunks]) reconstructed_text = "".join([chunk[0] for chunk in chunks])
assert reconstructed_text == input_text, ( assert (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" reconstructed_text == input_text
) ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -18,14 +18,14 @@ async def demonstrate_incremental_loading():
Demonstrate incremental file loading by creating a file, modifying it, Demonstrate incremental file loading by creating a file, modifying it,
and showing how only changed blocks are detected. and showing how only changed blocks are detected.
""" """
print("🚀 Cognee Incremental File Loading Demo") print("🚀 Cognee Incremental File Loading Demo")
print("=" * 50) print("=" * 50)
# Initialize the incremental loader # Initialize the incremental loader
incremental_loader = IncrementalLoader(block_size=512) # 512 byte blocks for demo incremental_loader = IncrementalLoader(block_size=512) # 512 byte blocks for demo
block_service = BlockHashService(block_size=512) block_service = BlockHashService(block_size=512)
# Create initial file content # Create initial file content
initial_content = b""" initial_content = b"""
This is the initial content of our test file. This is the initial content of our test file.
@ -40,7 +40,7 @@ Block 5: Excepteur sint occaecat cupidatat non proident, sunt in culpa.
This is the end of the initial content. This is the end of the initial content.
""" """
# Create modified content (change Block 2 and add Block 6) # Create modified content (change Block 2 and add Block 6)
modified_content = b""" modified_content = b"""
This is the initial content of our test file. This is the initial content of our test file.
@ -56,64 +56,70 @@ Block 6: NEW BLOCK - This is additional content that was added.
This is the end of the modified content. This is the end of the modified content.
""" """
print("1. Creating signatures for initial and modified versions...") print("1. Creating signatures for initial and modified versions...")
# Generate signatures # Generate signatures
initial_file = BytesIO(initial_content) initial_file = BytesIO(initial_content)
modified_file = BytesIO(modified_content) modified_file = BytesIO(modified_content)
initial_signature = block_service.generate_signature(initial_file, "test_file.txt") initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
modified_signature = block_service.generate_signature(modified_file, "test_file.txt") modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks") print(
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks") f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
)
print(
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
)
# Compare signatures to find changes # Compare signatures to find changes
print("\n2. Comparing signatures to detect changes...") print("\n2. Comparing signatures to detect changes...")
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature) changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature) change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
print(f" Changed blocks: {changed_blocks}") print(f" Changed blocks: {changed_blocks}")
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}") print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}") print(
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
)
# Generate delta # Generate delta
print("\n3. Generating delta for changed content...") print("\n3. Generating delta for changed content...")
initial_file.seek(0) initial_file.seek(0)
modified_file.seek(0) modified_file.seek(0)
delta = block_service.generate_delta(initial_file, modified_file, initial_signature) delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
print(f" Delta size: {len(delta.delta_data)} bytes") print(f" Delta size: {len(delta.delta_data)} bytes")
print(f" Changed blocks in delta: {delta.changed_blocks}") print(f" Changed blocks in delta: {delta.changed_blocks}")
# Demonstrate reconstruction # Demonstrate reconstruction
print("\n4. Reconstructing file from delta...") print("\n4. Reconstructing file from delta...")
initial_file.seek(0) initial_file.seek(0)
reconstructed = block_service.apply_delta(initial_file, delta) reconstructed = block_service.apply_delta(initial_file, delta)
reconstructed_content = reconstructed.read() reconstructed_content = reconstructed.read()
print(f" Reconstruction successful: {reconstructed_content == modified_content}") print(f" Reconstruction successful: {reconstructed_content == modified_content}")
print(f" Reconstructed size: {len(reconstructed_content)} bytes") print(f" Reconstructed size: {len(reconstructed_content)} bytes")
# Show block details # Show block details
print("\n5. Block-by-block analysis:") print("\n5. Block-by-block analysis:")
print(" Block | Status | Strong Hash (first 8 chars)") print(" Block | Status | Strong Hash (first 8 chars)")
print(" ------|----------|---------------------------") print(" ------|----------|---------------------------")
old_blocks = {b.block_index: b for b in initial_signature.blocks} old_blocks = {b.block_index: b for b in initial_signature.blocks}
new_blocks = {b.block_index: b for b in modified_signature.blocks} new_blocks = {b.block_index: b for b in modified_signature.blocks}
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys())) all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
for idx in all_indices: for idx in all_indices:
old_block = old_blocks.get(idx) old_block = old_blocks.get(idx)
new_block = new_blocks.get(idx) new_block = new_blocks.get(idx)
if old_block is None: if old_block is None:
status = "ADDED" status = "ADDED"
hash_display = new_block.strong_hash[:8] if new_block else "" hash_display = new_block.strong_hash[:8] if new_block else ""
@ -126,9 +132,9 @@ This is the end of the modified content.
else: else:
status = "MODIFIED" status = "MODIFIED"
hash_display = f"{old_block.strong_hash[:8]}{new_block.strong_hash[:8]}" hash_display = f"{old_block.strong_hash[:8]}{new_block.strong_hash[:8]}"
print(f" {idx:5d} | {status:8s} | {hash_display}") print(f" {idx:5d} | {status:8s} | {hash_display}")
print("\n✅ Incremental loading demo completed!") print("\n✅ Incremental loading demo completed!")
print("\nThis demonstrates how Cognee can efficiently process only the changed") print("\nThis demonstrates how Cognee can efficiently process only the changed")
print("parts of files, significantly reducing processing time for large files") print("parts of files, significantly reducing processing time for large files")
@ -139,35 +145,35 @@ async def demonstrate_with_cognee():
""" """
Demonstrate integration with Cognee's add functionality Demonstrate integration with Cognee's add functionality
""" """
print("\n" + "=" * 50) print("\n" + "=" * 50)
print("🔧 Integration with Cognee Add Functionality") print("🔧 Integration with Cognee Add Functionality")
print("=" * 50) print("=" * 50)
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("Initial content for Cognee processing.") f.write("Initial content for Cognee processing.")
temp_file_path = f.name temp_file_path = f.name
try: try:
print(f"1. Adding initial file: {temp_file_path}") print(f"1. Adding initial file: {temp_file_path}")
# Add file to Cognee # Add file to Cognee
await cognee.add(temp_file_path) await cognee.add(temp_file_path)
print(" ✅ File added successfully") print(" ✅ File added successfully")
# Modify the file # Modify the file
with open(temp_file_path, 'w') as f: with open(temp_file_path, "w") as f:
f.write("Modified content for Cognee processing with additional text.") f.write("Modified content for Cognee processing with additional text.")
print("2. Adding modified version of the same file...") print("2. Adding modified version of the same file...")
# Add modified file - this should trigger incremental processing # Add modified file - this should trigger incremental processing
await cognee.add(temp_file_path) await cognee.add(temp_file_path)
print(" ✅ Modified file processed with incremental loading") print(" ✅ Modified file processed with incremental loading")
finally: finally:
# Clean up # Clean up
if os.path.exists(temp_file_path): if os.path.exists(temp_file_path):
@ -176,11 +182,11 @@ async def demonstrate_with_cognee():
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
print("Starting Cognee Incremental Loading Demo...") print("Starting Cognee Incremental Loading Demo...")
# Run the demonstration # Run the demonstration
asyncio.run(demonstrate_incremental_loading()) asyncio.run(demonstrate_incremental_loading())
# Uncomment the line below to test with actual Cognee integration # Uncomment the line below to test with actual Cognee integration
# asyncio.run(demonstrate_with_cognee()) # asyncio.run(demonstrate_with_cognee())

View file

@ -4,7 +4,8 @@ Simple test for incremental loading functionality
""" """
import sys import sys
sys.path.insert(0, '.')
sys.path.insert(0, ".")
from io import BytesIO from io import BytesIO
from cognee.modules.ingestion.incremental import BlockHashService from cognee.modules.ingestion.incremental import BlockHashService
@ -14,13 +15,13 @@ def test_incremental_loading():
""" """
Simple test of the incremental loading functionality Simple test of the incremental loading functionality
""" """
print("🚀 Cognee Incremental File Loading Test") print("🚀 Cognee Incremental File Loading Test")
print("=" * 50) print("=" * 50)
# Initialize the block service # Initialize the block service
block_service = BlockHashService(block_size=64) # Small blocks for demo block_service = BlockHashService(block_size=64) # Small blocks for demo
# Create initial file content # Create initial file content
initial_content = b"""This is the initial content. initial_content = b"""This is the initial content.
Line 1: Lorem ipsum dolor sit amet Line 1: Lorem ipsum dolor sit amet
@ -28,7 +29,7 @@ Line 2: Consectetur adipiscing elit
Line 3: Sed do eiusmod tempor Line 3: Sed do eiusmod tempor
Line 4: Incididunt ut labore et dolore Line 4: Incididunt ut labore et dolore
Line 5: End of initial content""" Line 5: End of initial content"""
# Create modified content (change Line 2 and add Line 6) # Create modified content (change Line 2 and add Line 6)
modified_content = b"""This is the initial content. modified_content = b"""This is the initial content.
Line 1: Lorem ipsum dolor sit amet Line 1: Lorem ipsum dolor sit amet
@ -37,64 +38,70 @@ Line 3: Sed do eiusmod tempor
Line 4: Incididunt ut labore et dolore Line 4: Incididunt ut labore et dolore
Line 5: End of initial content Line 5: End of initial content
Line 6: NEW - This is additional content""" Line 6: NEW - This is additional content"""
print("1. Creating signatures for initial and modified versions...") print("1. Creating signatures for initial and modified versions...")
# Generate signatures # Generate signatures
initial_file = BytesIO(initial_content) initial_file = BytesIO(initial_content)
modified_file = BytesIO(modified_content) modified_file = BytesIO(modified_content)
initial_signature = block_service.generate_signature(initial_file, "test_file.txt") initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
modified_signature = block_service.generate_signature(modified_file, "test_file.txt") modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks") print(
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks") f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
)
print(
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
)
# Compare signatures to find changes # Compare signatures to find changes
print("\n2. Comparing signatures to detect changes...") print("\n2. Comparing signatures to detect changes...")
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature) changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature) change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
print(f" Changed blocks: {changed_blocks}") print(f" Changed blocks: {changed_blocks}")
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}") print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}") print(
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
)
# Generate delta # Generate delta
print("\n3. Generating delta for changed content...") print("\n3. Generating delta for changed content...")
initial_file.seek(0) initial_file.seek(0)
modified_file.seek(0) modified_file.seek(0)
delta = block_service.generate_delta(initial_file, modified_file, initial_signature) delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
print(f" Delta size: {len(delta.delta_data)} bytes") print(f" Delta size: {len(delta.delta_data)} bytes")
print(f" Changed blocks in delta: {delta.changed_blocks}") print(f" Changed blocks in delta: {delta.changed_blocks}")
# Demonstrate reconstruction # Demonstrate reconstruction
print("\n4. Reconstructing file from delta...") print("\n4. Reconstructing file from delta...")
initial_file.seek(0) initial_file.seek(0)
reconstructed = block_service.apply_delta(initial_file, delta) reconstructed = block_service.apply_delta(initial_file, delta)
reconstructed_content = reconstructed.read() reconstructed_content = reconstructed.read()
print(f" Reconstruction successful: {reconstructed_content == modified_content}") print(f" Reconstruction successful: {reconstructed_content == modified_content}")
print(f" Reconstructed size: {len(reconstructed_content)} bytes") print(f" Reconstructed size: {len(reconstructed_content)} bytes")
# Show block details # Show block details
print("\n5. Block-by-block analysis:") print("\n5. Block-by-block analysis:")
print(" Block | Status | Strong Hash (first 8 chars)") print(" Block | Status | Strong Hash (first 8 chars)")
print(" ------|----------|---------------------------") print(" ------|----------|---------------------------")
old_blocks = {b.block_index: b for b in initial_signature.blocks} old_blocks = {b.block_index: b for b in initial_signature.blocks}
new_blocks = {b.block_index: b for b in modified_signature.blocks} new_blocks = {b.block_index: b for b in modified_signature.blocks}
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys())) all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
for idx in all_indices: for idx in all_indices:
old_block = old_blocks.get(idx) old_block = old_blocks.get(idx)
new_block = new_blocks.get(idx) new_block = new_blocks.get(idx)
if old_block is None: if old_block is None:
status = "ADDED" status = "ADDED"
hash_display = new_block.strong_hash[:8] if new_block else "" hash_display = new_block.strong_hash[:8] if new_block else ""
@ -107,14 +114,14 @@ Line 6: NEW - This is additional content"""
else: else:
status = "MODIFIED" status = "MODIFIED"
hash_display = f"{old_block.strong_hash[:8]}{new_block.strong_hash[:8]}" hash_display = f"{old_block.strong_hash[:8]}{new_block.strong_hash[:8]}"
print(f" {idx:5d} | {status:8s} | {hash_display}") print(f" {idx:5d} | {status:8s} | {hash_display}")
print("\n✅ Incremental loading test completed!") print("\n✅ Incremental loading test completed!")
print("\nThis demonstrates how Cognee can efficiently process only the changed") print("\nThis demonstrates how Cognee can efficiently process only the changed")
print("parts of files, significantly reducing processing time for large files") print("parts of files, significantly reducing processing time for large files")
print("with small modifications.") print("with small modifications.")
return True return True
@ -124,4 +131,4 @@ if __name__ == "__main__":
print("\n🎉 Test passed successfully!") print("\n🎉 Test passed successfully!")
else: else:
print("\n❌ Test failed!") print("\n❌ Test failed!")
sys.exit(1) sys.exit(1)

View file

@ -9,96 +9,96 @@ from cognee.modules.ingestion.incremental import BlockHashService, IncrementalLo
class TestBlockHashService:
    """Unit tests for the core block-hashing service."""

    def test_signature_generation(self):
        """A generated signature records path, size, block size, and block data."""
        hasher = BlockHashService(block_size=10)
        payload = b"Hello, this is a test file for block hashing!"
        stream = BytesIO(payload)
        signature = hasher.generate_signature(stream, "test.txt")
        assert signature.file_path == "test.txt"
        assert signature.file_size == len(payload)
        assert signature.block_size == 10
        assert len(signature.blocks) > 0
        assert signature.signature_data is not None

    def test_change_detection(self):
        """A localized edit is reported as changed blocks — but not all blocks."""
        hasher = BlockHashService(block_size=10)
        # Two versions of the same text differing only in one middle word.
        before_stream = BytesIO(b"Hello, world! This is the original content.")
        after_stream = BytesIO(b"Hello, world! This is the MODIFIED content.")
        before_sig = hasher.generate_signature(before_stream)
        after_sig = hasher.generate_signature(after_stream)
        delta_blocks = hasher.compare_signatures(before_sig, after_sig)
        assert len(delta_blocks) > 0  # the edit must be detected
        assert len(delta_blocks) < len(before_sig.blocks)  # and stay localized

    def test_no_changes(self):
        """Identical byte streams produce signatures with no differing blocks."""
        hasher = BlockHashService(block_size=10)
        payload = b"This content will not change at all!"
        sig_a = hasher.generate_signature(BytesIO(payload))
        sig_b = hasher.generate_signature(BytesIO(payload))
        assert len(hasher.compare_signatures(sig_a, sig_b)) == 0

    def test_delta_generation(self):
        """A delta captures the changed region and reconstructs the new file."""
        hasher = BlockHashService(block_size=8)
        source_bytes = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        target_bytes = b"ABCDEFGHXXXXXXXXXXXXXXWXYZ"  # Change middle part
        source_stream = BytesIO(source_bytes)
        target_stream = BytesIO(target_bytes)
        # Generate a delta describing how to turn source into target.
        delta = hasher.generate_delta(source_stream, target_stream)
        assert len(delta.changed_blocks) > 0
        assert delta.delta_data is not None
        # Applying the delta to the rewound original must yield the new bytes.
        source_stream.seek(0)
        rebuilt = hasher.apply_delta(source_stream, delta)
        assert rebuilt.read() == target_bytes

    def test_block_statistics(self):
        """Change statistics count old, new, and changed blocks correctly."""
        hasher = BlockHashService(block_size=5)
        old_sig = hasher.generate_signature(BytesIO(b"ABCDEFGHIJ"))  # 2 blocks
        new_sig = hasher.generate_signature(BytesIO(b"ABCDEFXXXX"))  # 2 blocks, second one changed
        stats = hasher.calculate_block_changes(old_sig, new_sig)
        assert stats["total_old_blocks"] == 2
        assert stats["total_new_blocks"] == 2
        assert stats["changed_blocks"] == 1  # Only second block changed
@ -107,36 +107,36 @@ class TestBlockHashService:
# NOTE(review): this span is a rendered side-by-side diff — each source line
# appears twice on one physical line, and a raw hunk marker ("@ -144,19 ...")
# interrupts the FileSignature(...) call below, so at least one original line
# (presumably a content_hash/total_blocks keyword — TODO confirm against the
# repository) is missing from this view. Code is left byte-identical; restore
# the file from the repository before making behavioral edits here.
class TestIncrementalLoader: class TestIncrementalLoader:
"""Test the incremental loader integration""" """Test the incremental loader integration"""
# Placeholder async test: only constructs the loader and an in-memory file;
# the actual processing decision needs a mocked database (see comments below).
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_should_process_new_file(self): async def test_should_process_new_file(self):
"""Test processing decision for new files""" """Test processing decision for new files"""
loader = IncrementalLoader() loader = IncrementalLoader()
content = b"This is a new file that hasn't been seen before." content = b"This is a new file that hasn't been seen before."
file_obj = BytesIO(content) file_obj = BytesIO(content)
# For a new file (no existing signature), should process # For a new file (no existing signature), should process
# Note: This test would need a mock database setup in real implementation # Note: This test would need a mock database setup in real implementation
# For now, we test the logic without database interaction # For now, we test the logic without database interaction
pass # Placeholder for database-dependent test pass # Placeholder for database-dependent test
# Placeholder test: builds three 10-byte BlockInfo entries and a FileSignature
# plus a change_info dict marking only block 1 as changed; the extraction
# call itself is not exercised yet.
def test_block_data_extraction(self): def test_block_data_extraction(self):
"""Test extraction of changed block data""" """Test extraction of changed block data"""
loader = IncrementalLoader(block_size=10) loader = IncrementalLoader(block_size=10)
content = b"Block1____Block2____Block3____" content = b"Block1____Block2____Block3____"
file_obj = BytesIO(content) file_obj = BytesIO(content)
# Create mock change info # Create mock change info
from cognee.modules.ingestion.incremental.block_hash_service import BlockInfo, FileSignature from cognee.modules.ingestion.incremental.block_hash_service import BlockInfo, FileSignature
blocks = [ blocks = [
BlockInfo(0, 12345, "hash1", 10, 0), BlockInfo(0, 12345, "hash1", 10, 0),
BlockInfo(1, 23456, "hash2", 10, 10), BlockInfo(1, 23456, "hash2", 10, 10),
BlockInfo(2, 34567, "hash3", 10, 20), BlockInfo(2, 34567, "hash3", 10, 20),
] ]
signature = FileSignature( signature = FileSignature(
file_path="test", file_path="test",
file_size=30, file_size=30,
@ -144,19 +144,19 @@ class TestIncrementalLoader:
block_size=10, block_size=10,
strong_len=8, strong_len=8,
blocks=blocks, blocks=blocks,
signature_data=b"signature" signature_data=b"signature",
) )
change_info = { change_info = {
"type": "incremental_changes", "type": "incremental_changes",
"changed_blocks": [1], # Only middle block changed "changed_blocks": [1], # Only middle block changed
"new_signature": signature "new_signature": signature,
} }
# This would normally be called after should_process_file # This would normally be called after should_process_file
# Testing the block extraction logic # Testing the block extraction logic
pass # Placeholder for full integration test pass # Placeholder for full integration test
def _run_module_tests():
    """Run this module's tests under pytest when executed as a script."""
    pytest.main([__file__])


if __name__ == "__main__":
    _run_module_tests()