reformat applied
This commit is contained in:
parent
1ebeeac61d
commit
07f2afa69d
39 changed files with 577 additions and 531 deletions
|
|
@ -5,43 +5,49 @@ Revises: 1d0bb7fede17
|
||||||
Create Date: 2025-01-27 12:00:00.000000
|
Create Date: 2025-01-27 12:00:00.000000
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = 'incremental_file_signatures'
|
revision = "incremental_file_signatures"
|
||||||
down_revision = '1d0bb7fede17'
|
down_revision = "1d0bb7fede17"
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('file_signatures',
|
op.create_table(
|
||||||
sa.Column('id', sa.UUID(), nullable=False, default=uuid4),
|
"file_signatures",
|
||||||
sa.Column('data_id', sa.UUID(), nullable=True),
|
sa.Column("id", sa.UUID(), nullable=False, default=uuid4),
|
||||||
sa.Column('file_path', sa.String(), nullable=True),
|
sa.Column("data_id", sa.UUID(), nullable=True),
|
||||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
sa.Column("file_path", sa.String(), nullable=True),
|
||||||
sa.Column('content_hash', sa.String(), nullable=True),
|
sa.Column("file_size", sa.Integer(), nullable=True),
|
||||||
sa.Column('total_blocks', sa.Integer(), nullable=True),
|
sa.Column("content_hash", sa.String(), nullable=True),
|
||||||
sa.Column('block_size', sa.Integer(), nullable=True),
|
sa.Column("total_blocks", sa.Integer(), nullable=True),
|
||||||
sa.Column('strong_len', sa.Integer(), nullable=True),
|
sa.Column("block_size", sa.Integer(), nullable=True),
|
||||||
sa.Column('signature_data', sa.LargeBinary(), nullable=True),
|
sa.Column("strong_len", sa.Integer(), nullable=True),
|
||||||
sa.Column('blocks_info', sa.JSON(), nullable=True),
|
sa.Column("signature_data", sa.LargeBinary(), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
sa.Column("blocks_info", sa.JSON(), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_file_signatures_data_id"), "file_signatures", ["data_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_file_signatures_content_hash"), "file_signatures", ["content_hash"], unique=False
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_file_signatures_data_id'), 'file_signatures', ['data_id'], unique=False)
|
|
||||||
op.create_index(op.f('ix_file_signatures_content_hash'), 'file_signatures', ['content_hash'], unique=False)
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_index(op.f('ix_file_signatures_content_hash'), table_name='file_signatures')
|
op.drop_index(op.f("ix_file_signatures_content_hash"), table_name="file_signatures")
|
||||||
op.drop_index(op.f('ix_file_signatures_data_id'), table_name='file_signatures')
|
op.drop_index(op.f("ix_file_signatures_data_id"), table_name="file_signatures")
|
||||||
op.drop_table('file_signatures')
|
op.drop_table("file_signatures")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
|
||||||
|
|
@ -556,7 +556,7 @@ def log_database_configuration():
|
||||||
elif relational_config.db_provider == "sqlite":
|
elif relational_config.db_provider == "sqlite":
|
||||||
logger.info(f"SQLite path: {relational_config.db_path}")
|
logger.info(f"SQLite path: {relational_config.db_path}")
|
||||||
logger.info(f"SQLite database: {relational_config.db_name}")
|
logger.info(f"SQLite database: {relational_config.db_name}")
|
||||||
|
|
||||||
# Log vector database configuration
|
# Log vector database configuration
|
||||||
vector_config = get_vectordb_config()
|
vector_config = get_vectordb_config()
|
||||||
logger.info(f"Vector database: {vector_config.vector_db_provider}")
|
logger.info(f"Vector database: {vector_config.vector_db_provider}")
|
||||||
|
|
@ -564,7 +564,7 @@ def log_database_configuration():
|
||||||
logger.info(f"Vector database path: {vector_config.vector_db_url}")
|
logger.info(f"Vector database path: {vector_config.vector_db_url}")
|
||||||
elif vector_config.vector_db_provider in ["qdrant", "weaviate", "pgvector"]:
|
elif vector_config.vector_db_provider in ["qdrant", "weaviate", "pgvector"]:
|
||||||
logger.info(f"Vector database URL: {vector_config.vector_db_url}")
|
logger.info(f"Vector database URL: {vector_config.vector_db_url}")
|
||||||
|
|
||||||
# Log graph database configuration
|
# Log graph database configuration
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
logger.info(f"Graph database: {graph_config.graph_database_provider}")
|
logger.info(f"Graph database: {graph_config.graph_database_provider}")
|
||||||
|
|
@ -572,7 +572,7 @@ def log_database_configuration():
|
||||||
logger.info(f"Graph database path: {graph_config.graph_file_path}")
|
logger.info(f"Graph database path: {graph_config.graph_file_path}")
|
||||||
elif graph_config.graph_database_provider in ["neo4j", "falkordb"]:
|
elif graph_config.graph_database_provider in ["neo4j", "falkordb"]:
|
||||||
logger.info(f"Graph database URL: {graph_config.graph_database_url}")
|
logger.info(f"Graph database URL: {graph_config.graph_database_url}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not retrieve database configuration: {str(e)}")
|
logger.warning(f"Could not retrieve database configuration: {str(e)}")
|
||||||
|
|
||||||
|
|
@ -591,7 +591,7 @@ async def main():
|
||||||
|
|
||||||
# Log database configurations
|
# Log database configurations
|
||||||
log_database_configuration()
|
log_database_configuration()
|
||||||
|
|
||||||
logger.info(f"Starting MCP server with transport: {args.transport}")
|
logger.info(f"Starting MCP server with transport: {args.transport}")
|
||||||
if args.transport == "stdio":
|
if args.transport == "stdio":
|
||||||
await mcp.run_stdio_async()
|
await mcp.run_stdio_async()
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ litellm.set_verbose = False
|
||||||
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
|
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
|
||||||
logging.getLogger("litellm").setLevel(logging.CRITICAL)
|
logging.getLogger("litellm").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
Adapter for Generic API LLM provider API.
|
Adapter for Generic API LLM provider API.
|
||||||
|
|
|
||||||
|
|
@ -10,26 +10,28 @@ class FileSignature(Base):
|
||||||
__tablename__ = "file_signatures"
|
__tablename__ = "file_signatures"
|
||||||
|
|
||||||
id = Column(UUID, primary_key=True, default=uuid4)
|
id = Column(UUID, primary_key=True, default=uuid4)
|
||||||
|
|
||||||
# Reference to the original data entry
|
# Reference to the original data entry
|
||||||
data_id = Column(UUID, index=True)
|
data_id = Column(UUID, index=True)
|
||||||
|
|
||||||
# File information
|
# File information
|
||||||
file_path = Column(String)
|
file_path = Column(String)
|
||||||
file_size = Column(Integer)
|
file_size = Column(Integer)
|
||||||
content_hash = Column(String, index=True) # Overall file hash for quick comparison
|
content_hash = Column(String, index=True) # Overall file hash for quick comparison
|
||||||
|
|
||||||
# Block information
|
# Block information
|
||||||
total_blocks = Column(Integer)
|
total_blocks = Column(Integer)
|
||||||
block_size = Column(Integer)
|
block_size = Column(Integer)
|
||||||
strong_len = Column(Integer)
|
strong_len = Column(Integer)
|
||||||
|
|
||||||
# Signature data (binary)
|
# Signature data (binary)
|
||||||
signature_data = Column(LargeBinary)
|
signature_data = Column(LargeBinary)
|
||||||
|
|
||||||
# Block details (JSON array of block info)
|
# Block details (JSON array of block info)
|
||||||
blocks_info = Column(JSON) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
|
blocks_info = Column(
|
||||||
|
JSON
|
||||||
|
) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
@ -47,4 +49,4 @@ class FileSignature(Base):
|
||||||
"blocks_info": self.blocks_info,
|
"blocks_info": self.blocks_info,
|
||||||
"created_at": self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from .incremental_loader import IncrementalLoader
|
from .incremental_loader import IncrementalLoader
|
||||||
from .block_hash_service import BlockHashService
|
from .block_hash_service import BlockHashService
|
||||||
|
|
||||||
__all__ = ["IncrementalLoader", "BlockHashService"]
|
__all__ = ["IncrementalLoader", "BlockHashService"]
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import tempfile
|
||||||
@dataclass
|
@dataclass
|
||||||
class BlockInfo:
|
class BlockInfo:
|
||||||
"""Information about a file block"""
|
"""Information about a file block"""
|
||||||
|
|
||||||
block_index: int
|
block_index: int
|
||||||
weak_checksum: int
|
weak_checksum: int
|
||||||
strong_hash: str
|
strong_hash: str
|
||||||
|
|
@ -28,6 +29,7 @@ class BlockInfo:
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileSignature:
|
class FileSignature:
|
||||||
"""File signature containing block information"""
|
"""File signature containing block information"""
|
||||||
|
|
||||||
file_path: str
|
file_path: str
|
||||||
file_size: int
|
file_size: int
|
||||||
total_blocks: int
|
total_blocks: int
|
||||||
|
|
@ -40,6 +42,7 @@ class FileSignature:
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileDelta:
|
class FileDelta:
|
||||||
"""Delta information for changed blocks"""
|
"""Delta information for changed blocks"""
|
||||||
|
|
||||||
changed_blocks: List[int] # Block indices that changed
|
changed_blocks: List[int] # Block indices that changed
|
||||||
delta_data: bytes
|
delta_data: bytes
|
||||||
old_signature: FileSignature
|
old_signature: FileSignature
|
||||||
|
|
@ -48,53 +51,51 @@ class FileDelta:
|
||||||
|
|
||||||
class BlockHashService:
|
class BlockHashService:
|
||||||
"""Service for block-based file hashing using librsync algorithm"""
|
"""Service for block-based file hashing using librsync algorithm"""
|
||||||
|
|
||||||
DEFAULT_BLOCK_SIZE = 1024 # 1KB blocks
|
DEFAULT_BLOCK_SIZE = 1024 # 1KB blocks
|
||||||
DEFAULT_STRONG_LEN = 8 # 8 bytes for strong hash
|
DEFAULT_STRONG_LEN = 8 # 8 bytes for strong hash
|
||||||
|
|
||||||
def __init__(self, block_size: int = None, strong_len: int = None):
|
def __init__(self, block_size: int = None, strong_len: int = None):
|
||||||
"""
|
"""
|
||||||
Initialize the BlockHashService
|
Initialize the BlockHashService
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
block_size: Size of blocks in bytes (default: 1024)
|
block_size: Size of blocks in bytes (default: 1024)
|
||||||
strong_len: Length of strong hash in bytes (default: 8)
|
strong_len: Length of strong hash in bytes (default: 8)
|
||||||
"""
|
"""
|
||||||
self.block_size = block_size or self.DEFAULT_BLOCK_SIZE
|
self.block_size = block_size or self.DEFAULT_BLOCK_SIZE
|
||||||
self.strong_len = strong_len or self.DEFAULT_STRONG_LEN
|
self.strong_len = strong_len or self.DEFAULT_STRONG_LEN
|
||||||
|
|
||||||
def generate_signature(self, file_obj: BinaryIO, file_path: str = None) -> FileSignature:
|
def generate_signature(self, file_obj: BinaryIO, file_path: str = None) -> FileSignature:
|
||||||
"""
|
"""
|
||||||
Generate a signature for a file using librsync algorithm
|
Generate a signature for a file using librsync algorithm
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_obj: File object to generate signature for
|
file_obj: File object to generate signature for
|
||||||
file_path: Optional file path for metadata
|
file_path: Optional file path for metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FileSignature object containing block information
|
FileSignature object containing block information
|
||||||
"""
|
"""
|
||||||
file_obj.seek(0)
|
file_obj.seek(0)
|
||||||
file_data = file_obj.read()
|
file_data = file_obj.read()
|
||||||
file_size = len(file_data)
|
file_size = len(file_data)
|
||||||
|
|
||||||
# Calculate optimal signature parameters
|
# Calculate optimal signature parameters
|
||||||
magic, block_len, strong_len = get_signature_args(
|
magic, block_len, strong_len = get_signature_args(
|
||||||
file_size,
|
file_size, block_len=self.block_size, strong_len=self.strong_len
|
||||||
block_len=self.block_size,
|
|
||||||
strong_len=self.strong_len
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate signature using librsync
|
# Generate signature using librsync
|
||||||
file_io = BytesIO(file_data)
|
file_io = BytesIO(file_data)
|
||||||
sig_io = BytesIO()
|
sig_io = BytesIO()
|
||||||
|
|
||||||
signature(file_io, sig_io, strong_len, magic, block_len)
|
signature(file_io, sig_io, strong_len, magic, block_len)
|
||||||
signature_data = sig_io.getvalue()
|
signature_data = sig_io.getvalue()
|
||||||
|
|
||||||
# Parse signature to extract block information
|
# Parse signature to extract block information
|
||||||
blocks = self._parse_signature(signature_data, file_data, block_len)
|
blocks = self._parse_signature(signature_data, file_data, block_len)
|
||||||
|
|
||||||
return FileSignature(
|
return FileSignature(
|
||||||
file_path=file_path or "",
|
file_path=file_path or "",
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
|
|
@ -102,52 +103,56 @@ class BlockHashService:
|
||||||
block_size=block_len,
|
block_size=block_len,
|
||||||
strong_len=strong_len,
|
strong_len=strong_len,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
signature_data=signature_data
|
signature_data=signature_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_signature(self, signature_data: bytes, file_data: bytes, block_size: int) -> List[BlockInfo]:
|
def _parse_signature(
|
||||||
|
self, signature_data: bytes, file_data: bytes, block_size: int
|
||||||
|
) -> List[BlockInfo]:
|
||||||
"""
|
"""
|
||||||
Parse signature data to extract block information
|
Parse signature data to extract block information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signature_data: Raw signature data from librsync
|
signature_data: Raw signature data from librsync
|
||||||
file_data: Original file data
|
file_data: Original file data
|
||||||
block_size: Size of blocks
|
block_size: Size of blocks
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of BlockInfo objects
|
List of BlockInfo objects
|
||||||
"""
|
"""
|
||||||
blocks = []
|
blocks = []
|
||||||
total_blocks = (len(file_data) + block_size - 1) // block_size
|
total_blocks = (len(file_data) + block_size - 1) // block_size
|
||||||
|
|
||||||
for i in range(total_blocks):
|
for i in range(total_blocks):
|
||||||
start_offset = i * block_size
|
start_offset = i * block_size
|
||||||
end_offset = min(start_offset + block_size, len(file_data))
|
end_offset = min(start_offset + block_size, len(file_data))
|
||||||
block_data = file_data[start_offset:end_offset]
|
block_data = file_data[start_offset:end_offset]
|
||||||
|
|
||||||
# Calculate weak checksum (simple Adler-32 variant)
|
# Calculate weak checksum (simple Adler-32 variant)
|
||||||
weak_checksum = self._calculate_weak_checksum(block_data)
|
weak_checksum = self._calculate_weak_checksum(block_data)
|
||||||
|
|
||||||
# Calculate strong hash (MD5)
|
# Calculate strong hash (MD5)
|
||||||
strong_hash = hashlib.md5(block_data).hexdigest()
|
strong_hash = hashlib.md5(block_data).hexdigest()
|
||||||
|
|
||||||
blocks.append(BlockInfo(
|
blocks.append(
|
||||||
block_index=i,
|
BlockInfo(
|
||||||
weak_checksum=weak_checksum,
|
block_index=i,
|
||||||
strong_hash=strong_hash,
|
weak_checksum=weak_checksum,
|
||||||
block_size=len(block_data),
|
strong_hash=strong_hash,
|
||||||
file_offset=start_offset
|
block_size=len(block_data),
|
||||||
))
|
file_offset=start_offset,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
def _calculate_weak_checksum(self, data: bytes) -> int:
|
def _calculate_weak_checksum(self, data: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
Calculate a weak checksum similar to Adler-32
|
Calculate a weak checksum similar to Adler-32
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Block data
|
data: Block data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Weak checksum value
|
Weak checksum value
|
||||||
"""
|
"""
|
||||||
|
|
@ -157,111 +162,116 @@ class BlockHashService:
|
||||||
a = (a + byte) % 65521
|
a = (a + byte) % 65521
|
||||||
b = (b + a) % 65521
|
b = (b + a) % 65521
|
||||||
return (b << 16) | a
|
return (b << 16) | a
|
||||||
|
|
||||||
def compare_signatures(self, old_sig: FileSignature, new_sig: FileSignature) -> List[int]:
|
def compare_signatures(self, old_sig: FileSignature, new_sig: FileSignature) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Compare two signatures to find changed blocks
|
Compare two signatures to find changed blocks
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
old_sig: Previous file signature
|
old_sig: Previous file signature
|
||||||
new_sig: New file signature
|
new_sig: New file signature
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of block indices that have changed
|
List of block indices that have changed
|
||||||
"""
|
"""
|
||||||
changed_blocks = []
|
changed_blocks = []
|
||||||
|
|
||||||
# Create lookup tables for efficient comparison
|
# Create lookup tables for efficient comparison
|
||||||
old_blocks = {block.block_index: block for block in old_sig.blocks}
|
old_blocks = {block.block_index: block for block in old_sig.blocks}
|
||||||
new_blocks = {block.block_index: block for block in new_sig.blocks}
|
new_blocks = {block.block_index: block for block in new_sig.blocks}
|
||||||
|
|
||||||
# Find changed, added, or removed blocks
|
# Find changed, added, or removed blocks
|
||||||
all_indices = set(old_blocks.keys()) | set(new_blocks.keys())
|
all_indices = set(old_blocks.keys()) | set(new_blocks.keys())
|
||||||
|
|
||||||
for block_idx in all_indices:
|
for block_idx in all_indices:
|
||||||
old_block = old_blocks.get(block_idx)
|
old_block = old_blocks.get(block_idx)
|
||||||
new_block = new_blocks.get(block_idx)
|
new_block = new_blocks.get(block_idx)
|
||||||
|
|
||||||
if old_block is None or new_block is None:
|
if old_block is None or new_block is None:
|
||||||
# Block was added or removed
|
# Block was added or removed
|
||||||
changed_blocks.append(block_idx)
|
changed_blocks.append(block_idx)
|
||||||
elif (old_block.weak_checksum != new_block.weak_checksum or
|
elif (
|
||||||
old_block.strong_hash != new_block.strong_hash):
|
old_block.weak_checksum != new_block.weak_checksum
|
||||||
|
or old_block.strong_hash != new_block.strong_hash
|
||||||
|
):
|
||||||
# Block content changed
|
# Block content changed
|
||||||
changed_blocks.append(block_idx)
|
changed_blocks.append(block_idx)
|
||||||
|
|
||||||
return sorted(changed_blocks)
|
return sorted(changed_blocks)
|
||||||
|
|
||||||
def generate_delta(self, old_file: BinaryIO, new_file: BinaryIO,
|
def generate_delta(
|
||||||
old_signature: FileSignature = None) -> FileDelta:
|
self, old_file: BinaryIO, new_file: BinaryIO, old_signature: FileSignature = None
|
||||||
|
) -> FileDelta:
|
||||||
"""
|
"""
|
||||||
Generate a delta between two file versions
|
Generate a delta between two file versions
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
old_file: Previous version of the file
|
old_file: Previous version of the file
|
||||||
new_file: New version of the file
|
new_file: New version of the file
|
||||||
old_signature: Optional pre-computed signature of old file
|
old_signature: Optional pre-computed signature of old file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FileDelta object containing change information
|
FileDelta object containing change information
|
||||||
"""
|
"""
|
||||||
# Generate signatures if not provided
|
# Generate signatures if not provided
|
||||||
if old_signature is None:
|
if old_signature is None:
|
||||||
old_signature = self.generate_signature(old_file)
|
old_signature = self.generate_signature(old_file)
|
||||||
|
|
||||||
new_signature = self.generate_signature(new_file)
|
new_signature = self.generate_signature(new_file)
|
||||||
|
|
||||||
# Generate delta using librsync
|
# Generate delta using librsync
|
||||||
new_file.seek(0)
|
new_file.seek(0)
|
||||||
old_sig_io = BytesIO(old_signature.signature_data)
|
old_sig_io = BytesIO(old_signature.signature_data)
|
||||||
delta_io = BytesIO()
|
delta_io = BytesIO()
|
||||||
|
|
||||||
delta(new_file, old_sig_io, delta_io)
|
delta(new_file, old_sig_io, delta_io)
|
||||||
delta_data = delta_io.getvalue()
|
delta_data = delta_io.getvalue()
|
||||||
|
|
||||||
# Find changed blocks
|
# Find changed blocks
|
||||||
changed_blocks = self.compare_signatures(old_signature, new_signature)
|
changed_blocks = self.compare_signatures(old_signature, new_signature)
|
||||||
|
|
||||||
return FileDelta(
|
return FileDelta(
|
||||||
changed_blocks=changed_blocks,
|
changed_blocks=changed_blocks,
|
||||||
delta_data=delta_data,
|
delta_data=delta_data,
|
||||||
old_signature=old_signature,
|
old_signature=old_signature,
|
||||||
new_signature=new_signature
|
new_signature=new_signature,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
|
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
|
||||||
"""
|
"""
|
||||||
Apply a delta to reconstruct the new file
|
Apply a delta to reconstruct the new file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
old_file: Original file
|
old_file: Original file
|
||||||
delta_obj: Delta information
|
delta_obj: Delta information
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BytesIO object containing the reconstructed file
|
BytesIO object containing the reconstructed file
|
||||||
"""
|
"""
|
||||||
old_file.seek(0)
|
old_file.seek(0)
|
||||||
delta_io = BytesIO(delta_obj.delta_data)
|
delta_io = BytesIO(delta_obj.delta_data)
|
||||||
result_io = BytesIO()
|
result_io = BytesIO()
|
||||||
|
|
||||||
patch(old_file, delta_io, result_io)
|
patch(old_file, delta_io, result_io)
|
||||||
result_io.seek(0)
|
result_io.seek(0)
|
||||||
|
|
||||||
return result_io
|
return result_io
|
||||||
|
|
||||||
def calculate_block_changes(self, old_sig: FileSignature, new_sig: FileSignature) -> Dict[str, Any]:
|
def calculate_block_changes(
|
||||||
|
self, old_sig: FileSignature, new_sig: FileSignature
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Calculate detailed statistics about block changes
|
Calculate detailed statistics about block changes
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
old_sig: Previous file signature
|
old_sig: Previous file signature
|
||||||
new_sig: New file signature
|
new_sig: New file signature
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with change statistics
|
Dictionary with change statistics
|
||||||
"""
|
"""
|
||||||
changed_blocks = self.compare_signatures(old_sig, new_sig)
|
changed_blocks = self.compare_signatures(old_sig, new_sig)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_old_blocks": len(old_sig.blocks),
|
"total_old_blocks": len(old_sig.blocks),
|
||||||
"total_new_blocks": len(new_sig.blocks),
|
"total_new_blocks": len(new_sig.blocks),
|
||||||
|
|
@ -271,4 +281,4 @@ class BlockHashService:
|
||||||
"compression_ratio": 1.0 - (len(changed_blocks) / max(len(old_sig.blocks), 1)),
|
"compression_ratio": 1.0 - (len(changed_blocks) / max(len(old_sig.blocks), 1)),
|
||||||
"old_file_size": old_sig.file_size,
|
"old_file_size": old_sig.file_size,
|
||||||
"new_file_size": new_sig.file_size,
|
"new_file_size": new_sig.file_size,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,99 +26,107 @@ class IncrementalLoader:
|
||||||
"""
|
"""
|
||||||
Incremental file loader using rsync algorithm for efficient updates
|
Incremental file loader using rsync algorithm for efficient updates
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, block_size: int = 1024, strong_len: int = 8):
|
def __init__(self, block_size: int = 1024, strong_len: int = 8):
|
||||||
"""
|
"""
|
||||||
Initialize the incremental loader
|
Initialize the incremental loader
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
block_size: Size of blocks in bytes for rsync algorithm
|
block_size: Size of blocks in bytes for rsync algorithm
|
||||||
strong_len: Length of strong hash in bytes
|
strong_len: Length of strong hash in bytes
|
||||||
"""
|
"""
|
||||||
self.block_service = BlockHashService(block_size, strong_len)
|
self.block_service = BlockHashService(block_size, strong_len)
|
||||||
|
|
||||||
async def should_process_file(self, file_obj: BinaryIO, data_id: str) -> Tuple[bool, Optional[Dict]]:
|
async def should_process_file(
|
||||||
|
self, file_obj: BinaryIO, data_id: str
|
||||||
|
) -> Tuple[bool, Optional[Dict]]:
|
||||||
"""
|
"""
|
||||||
Determine if a file should be processed based on incremental changes
|
Determine if a file should be processed based on incremental changes
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_obj: File object to check
|
file_obj: File object to check
|
||||||
data_id: Data ID for the file
|
data_id: Data ID for the file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (should_process, change_info)
|
Tuple of (should_process, change_info)
|
||||||
- should_process: True if file needs processing
|
- should_process: True if file needs processing
|
||||||
- change_info: Dictionary with change details if applicable
|
- change_info: Dictionary with change details if applicable
|
||||||
"""
|
"""
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Check if we have an existing signature for this file
|
# Check if we have an existing signature for this file
|
||||||
existing_signature = await self._get_existing_signature(session, data_id)
|
existing_signature = await self._get_existing_signature(session, data_id)
|
||||||
|
|
||||||
if existing_signature is None:
|
if existing_signature is None:
|
||||||
# First time seeing this file, needs full processing
|
# First time seeing this file, needs full processing
|
||||||
return True, {"type": "new_file", "full_processing": True}
|
return True, {"type": "new_file", "full_processing": True}
|
||||||
|
|
||||||
# Generate signature for current file version
|
# Generate signature for current file version
|
||||||
current_signature = self.block_service.generate_signature(file_obj)
|
current_signature = self.block_service.generate_signature(file_obj)
|
||||||
|
|
||||||
# Quick check: if overall content hash is the same, no changes
|
# Quick check: if overall content hash is the same, no changes
|
||||||
file_obj.seek(0)
|
file_obj.seek(0)
|
||||||
current_content_hash = get_file_content_hash(file_obj)
|
current_content_hash = get_file_content_hash(file_obj)
|
||||||
|
|
||||||
if current_content_hash == existing_signature.content_hash:
|
if current_content_hash == existing_signature.content_hash:
|
||||||
return False, {"type": "no_changes", "full_processing": False}
|
return False, {"type": "no_changes", "full_processing": False}
|
||||||
|
|
||||||
# Convert database signature to service signature for comparison
|
# Convert database signature to service signature for comparison
|
||||||
service_old_sig = self._db_signature_to_service(existing_signature)
|
service_old_sig = self._db_signature_to_service(existing_signature)
|
||||||
|
|
||||||
# Compare signatures to find changed blocks
|
# Compare signatures to find changed blocks
|
||||||
changed_blocks = self.block_service.compare_signatures(service_old_sig, current_signature)
|
changed_blocks = self.block_service.compare_signatures(
|
||||||
|
service_old_sig, current_signature
|
||||||
|
)
|
||||||
|
|
||||||
if not changed_blocks:
|
if not changed_blocks:
|
||||||
# Signatures match, no processing needed
|
# Signatures match, no processing needed
|
||||||
return False, {"type": "no_changes", "full_processing": False}
|
return False, {"type": "no_changes", "full_processing": False}
|
||||||
|
|
||||||
# Calculate change statistics
|
# Calculate change statistics
|
||||||
change_stats = self.block_service.calculate_block_changes(service_old_sig, current_signature)
|
change_stats = self.block_service.calculate_block_changes(
|
||||||
|
service_old_sig, current_signature
|
||||||
|
)
|
||||||
|
|
||||||
change_info = {
|
change_info = {
|
||||||
"type": "incremental_changes",
|
"type": "incremental_changes",
|
||||||
"full_processing": len(changed_blocks) > (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
|
"full_processing": len(changed_blocks)
|
||||||
|
> (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
|
||||||
"changed_blocks": changed_blocks,
|
"changed_blocks": changed_blocks,
|
||||||
"stats": change_stats,
|
"stats": change_stats,
|
||||||
"new_signature": current_signature,
|
"new_signature": current_signature,
|
||||||
"old_signature": service_old_sig,
|
"old_signature": service_old_sig,
|
||||||
}
|
}
|
||||||
|
|
||||||
return True, change_info
|
return True, change_info
|
||||||
|
|
||||||
async def process_incremental_changes(self, file_obj: BinaryIO, data_id: str,
|
async def process_incremental_changes(
|
||||||
change_info: Dict) -> List[Dict]:
|
self, file_obj: BinaryIO, data_id: str, change_info: Dict
|
||||||
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Process only the changed blocks of a file
|
Process only the changed blocks of a file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_obj: File object to process
|
file_obj: File object to process
|
||||||
data_id: Data ID for the file
|
data_id: Data ID for the file
|
||||||
change_info: Change information from should_process_file
|
change_info: Change information from should_process_file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of block data that needs reprocessing
|
List of block data that needs reprocessing
|
||||||
"""
|
"""
|
||||||
if change_info["type"] != "incremental_changes":
|
if change_info["type"] != "incremental_changes":
|
||||||
raise ValueError("Invalid change_info type for incremental processing")
|
raise ValueError("Invalid change_info type for incremental processing")
|
||||||
|
|
||||||
file_obj.seek(0)
|
file_obj.seek(0)
|
||||||
file_data = file_obj.read()
|
file_data = file_obj.read()
|
||||||
|
|
||||||
changed_blocks = change_info["changed_blocks"]
|
changed_blocks = change_info["changed_blocks"]
|
||||||
new_signature = change_info["new_signature"]
|
new_signature = change_info["new_signature"]
|
||||||
|
|
||||||
# Extract data for changed blocks
|
# Extract data for changed blocks
|
||||||
changed_block_data = []
|
changed_block_data = []
|
||||||
|
|
||||||
for block_idx in changed_blocks:
|
for block_idx in changed_blocks:
|
||||||
# Find the block info
|
# Find the block info
|
||||||
block_info = None
|
block_info = None
|
||||||
|
|
@ -126,49 +134,51 @@ class IncrementalLoader:
|
||||||
if block.block_index == block_idx:
|
if block.block_index == block_idx:
|
||||||
block_info = block
|
block_info = block
|
||||||
break
|
break
|
||||||
|
|
||||||
if block_info is None:
|
if block_info is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract block data
|
# Extract block data
|
||||||
start_offset = block_info.file_offset
|
start_offset = block_info.file_offset
|
||||||
end_offset = start_offset + block_info.block_size
|
end_offset = start_offset + block_info.block_size
|
||||||
block_data = file_data[start_offset:end_offset]
|
block_data = file_data[start_offset:end_offset]
|
||||||
|
|
||||||
changed_block_data.append({
|
changed_block_data.append(
|
||||||
"block_index": block_idx,
|
{
|
||||||
"block_data": block_data,
|
"block_index": block_idx,
|
||||||
"block_info": block_info,
|
"block_data": block_data,
|
||||||
"file_offset": start_offset,
|
"block_info": block_info,
|
||||||
"block_size": len(block_data),
|
"file_offset": start_offset,
|
||||||
})
|
"block_size": len(block_data),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return changed_block_data
|
return changed_block_data
|
||||||
|
|
||||||
async def save_file_signature(self, file_obj: BinaryIO, data_id: str) -> None:
|
async def save_file_signature(self, file_obj: BinaryIO, data_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Save or update the file signature in the database
|
Save or update the file signature in the database
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_obj: File object
|
file_obj: File object
|
||||||
data_id: Data ID for the file
|
data_id: Data ID for the file
|
||||||
"""
|
"""
|
||||||
# Generate signature
|
# Generate signature
|
||||||
signature = self.block_service.generate_signature(file_obj, str(data_id))
|
signature = self.block_service.generate_signature(file_obj, str(data_id))
|
||||||
|
|
||||||
# Calculate content hash
|
# Calculate content hash
|
||||||
file_obj.seek(0)
|
file_obj.seek(0)
|
||||||
content_hash = get_file_content_hash(file_obj)
|
content_hash = get_file_content_hash(file_obj)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Check if signature already exists
|
# Check if signature already exists
|
||||||
existing = await session.execute(
|
existing = await session.execute(
|
||||||
select(FileSignature).filter(FileSignature.data_id == data_id)
|
select(FileSignature).filter(FileSignature.data_id == data_id)
|
||||||
)
|
)
|
||||||
existing_signature = existing.scalar_one_or_none()
|
existing_signature = existing.scalar_one_or_none()
|
||||||
|
|
||||||
# Prepare block info for JSON storage
|
# Prepare block info for JSON storage
|
||||||
blocks_info = [
|
blocks_info = [
|
||||||
{
|
{
|
||||||
|
|
@ -180,7 +190,7 @@ class IncrementalLoader:
|
||||||
}
|
}
|
||||||
for block in signature.blocks
|
for block in signature.blocks
|
||||||
]
|
]
|
||||||
|
|
||||||
if existing_signature:
|
if existing_signature:
|
||||||
# Update existing signature
|
# Update existing signature
|
||||||
existing_signature.file_path = signature.file_path
|
existing_signature.file_path = signature.file_path
|
||||||
|
|
@ -191,7 +201,7 @@ class IncrementalLoader:
|
||||||
existing_signature.strong_len = signature.strong_len
|
existing_signature.strong_len = signature.strong_len
|
||||||
existing_signature.signature_data = signature.signature_data
|
existing_signature.signature_data = signature.signature_data
|
||||||
existing_signature.blocks_info = blocks_info
|
existing_signature.blocks_info = blocks_info
|
||||||
|
|
||||||
await session.merge(existing_signature)
|
await session.merge(existing_signature)
|
||||||
else:
|
else:
|
||||||
# Create new signature
|
# Create new signature
|
||||||
|
|
@ -207,17 +217,19 @@ class IncrementalLoader:
|
||||||
blocks_info=blocks_info,
|
blocks_info=blocks_info,
|
||||||
)
|
)
|
||||||
session.add(new_signature)
|
session.add(new_signature)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _get_existing_signature(self, session: AsyncSession, data_id: str) -> Optional[FileSignature]:
|
async def _get_existing_signature(
|
||||||
|
self, session: AsyncSession, data_id: str
|
||||||
|
) -> Optional[FileSignature]:
|
||||||
"""
|
"""
|
||||||
Get existing file signature from database
|
Get existing file signature from database
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: Database session
|
session: Database session
|
||||||
data_id: Data ID to search for
|
data_id: Data ID to search for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FileSignature object or None if not found
|
FileSignature object or None if not found
|
||||||
"""
|
"""
|
||||||
|
|
@ -225,19 +237,19 @@ class IncrementalLoader:
|
||||||
select(FileSignature).filter(FileSignature.data_id == data_id)
|
select(FileSignature).filter(FileSignature.data_id == data_id)
|
||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
def _db_signature_to_service(self, db_signature: FileSignature) -> ServiceFileSignature:
|
def _db_signature_to_service(self, db_signature: FileSignature) -> ServiceFileSignature:
|
||||||
"""
|
"""
|
||||||
Convert database FileSignature to service FileSignature
|
Convert database FileSignature to service FileSignature
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_signature: Database signature object
|
db_signature: Database signature object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Service FileSignature object
|
Service FileSignature object
|
||||||
"""
|
"""
|
||||||
from .block_hash_service import BlockInfo
|
from .block_hash_service import BlockInfo
|
||||||
|
|
||||||
# Convert blocks info
|
# Convert blocks info
|
||||||
blocks = [
|
blocks = [
|
||||||
BlockInfo(
|
BlockInfo(
|
||||||
|
|
@ -249,7 +261,7 @@ class IncrementalLoader:
|
||||||
)
|
)
|
||||||
for block in db_signature.blocks_info
|
for block in db_signature.blocks_info
|
||||||
]
|
]
|
||||||
|
|
||||||
return ServiceFileSignature(
|
return ServiceFileSignature(
|
||||||
file_path=db_signature.file_path,
|
file_path=db_signature.file_path,
|
||||||
file_size=db_signature.file_size,
|
file_size=db_signature.file_size,
|
||||||
|
|
@ -259,26 +271,26 @@ class IncrementalLoader:
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
signature_data=db_signature.signature_data,
|
signature_data=db_signature.signature_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def cleanup_orphaned_signatures(self) -> int:
|
async def cleanup_orphaned_signatures(self) -> int:
|
||||||
"""
|
"""
|
||||||
Clean up file signatures that no longer have corresponding data entries
|
Clean up file signatures that no longer have corresponding data entries
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of signatures removed
|
Number of signatures removed
|
||||||
"""
|
"""
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Find signatures without corresponding data entries
|
# Find signatures without corresponding data entries
|
||||||
orphaned_query = """
|
orphaned_query = """
|
||||||
DELETE FROM file_signatures
|
DELETE FROM file_signatures
|
||||||
WHERE data_id NOT IN (SELECT id FROM data)
|
WHERE data_id NOT IN (SELECT id FROM data)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await session.execute(orphaned_query)
|
result = await session.execute(orphaned_query)
|
||||||
removed_count = result.rowcount
|
removed_count = result.rowcount
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return removed_count
|
return removed_count
|
||||||
|
|
|
||||||
|
|
@ -109,13 +109,15 @@ async def ingest_data(
|
||||||
data_id = ingestion.identify(classified_data, user)
|
data_id = ingestion.identify(classified_data, user)
|
||||||
|
|
||||||
file_metadata = classified_data.get_metadata()
|
file_metadata = classified_data.get_metadata()
|
||||||
|
|
||||||
# Initialize incremental loader for this file
|
# Initialize incremental loader for this file
|
||||||
incremental_loader = IncrementalLoader()
|
incremental_loader = IncrementalLoader()
|
||||||
|
|
||||||
# Check if file needs incremental processing
|
# Check if file needs incremental processing
|
||||||
should_process, change_info = await incremental_loader.should_process_file(file, data_id)
|
should_process, change_info = await incremental_loader.should_process_file(
|
||||||
|
file, data_id
|
||||||
|
)
|
||||||
|
|
||||||
# Save updated file signature regardless of whether processing is needed
|
# Save updated file signature regardless of whether processing is needed
|
||||||
await incremental_loader.save_file_signature(file, data_id)
|
await incremental_loader.save_file_signature(file, data_id)
|
||||||
|
|
||||||
|
|
@ -150,12 +152,12 @@ async def ingest_data(
|
||||||
ext_metadata = get_external_metadata_dict(data_item)
|
ext_metadata = get_external_metadata_dict(data_item)
|
||||||
if node_set:
|
if node_set:
|
||||||
ext_metadata["node_set"] = node_set
|
ext_metadata["node_set"] = node_set
|
||||||
|
|
||||||
# Add incremental processing metadata
|
# Add incremental processing metadata
|
||||||
ext_metadata["incremental_processing"] = {
|
ext_metadata["incremental_processing"] = {
|
||||||
"should_process": should_process,
|
"should_process": should_process,
|
||||||
"change_info": change_info,
|
"change_info": change_info,
|
||||||
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat()))
|
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())),
|
||||||
}
|
}
|
||||||
|
|
||||||
if data_point is not None:
|
if data_point is not None:
|
||||||
|
|
|
||||||
|
|
@ -51,12 +51,12 @@ def test_AudioDocument(mock_engine):
|
||||||
GROUND_TRUTH,
|
GROUND_TRUTH,
|
||||||
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -34,12 +34,12 @@ def test_ImageDocument(mock_engine):
|
||||||
GROUND_TRUTH,
|
GROUND_TRUTH,
|
||||||
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -36,12 +36,12 @@ def test_PdfDocument(mock_engine):
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
|
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -49,12 +49,12 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
|
||||||
GROUND_TRUTH[input_file],
|
GROUND_TRUTH[input_file],
|
||||||
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
|
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -79,32 +79,32 @@ def test_UnstructuredDocument(mock_engine):
|
||||||
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
|
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
|
||||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test DOCX
|
# Test DOCX
|
||||||
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
|
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
|
||||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_end" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_end != {paragraph_data.cut_type = }"
|
"sentence_end" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_end != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# TEST CSV
|
# TEST CSV
|
||||||
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
|
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
|
||||||
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
assert (
|
||||||
f"Read text doesn't match expected text: {paragraph_data.text}"
|
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
|
||||||
)
|
), f"Read text doesn't match expected text: {paragraph_data.text}"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test XLSX
|
# Test XLSX
|
||||||
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
|
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
|
||||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ async def check_graph_metrics_consistency_across_adapters(include_optional=False
|
||||||
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
||||||
|
|
||||||
for key, neo4j_value in neo4j_metrics.items():
|
for key, neo4j_value in neo4j_metrics.items():
|
||||||
assert networkx_metrics[key] == neo4j_value, (
|
assert (
|
||||||
f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
|
networkx_metrics[key] == neo4j_value
|
||||||
)
|
), f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,6 @@ async def assert_metrics(provider, include_optional=True):
|
||||||
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
||||||
|
|
||||||
for key, ground_truth_value in ground_truth_metrics.items():
|
for key, ground_truth_value in ground_truth_metrics.items():
|
||||||
assert metrics[key] == ground_truth_value, (
|
assert (
|
||||||
f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
|
metrics[key] == ground_truth_value
|
||||||
)
|
), f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
|
||||||
|
|
|
||||||
|
|
@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
|
||||||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||||
# Get data entry from database based on hash contents
|
# Get data entry from database based on hash contents
|
||||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test deletion of data along with local files created by cognee
|
# Test deletion of data along with local files created by cognee
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert not os.path.exists(data.raw_data_location), (
|
assert not os.path.exists(
|
||||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location still exists after deletion: {data.raw_data_location}"
|
||||||
|
|
||||||
async with engine.get_async_session() as session:
|
async with engine.get_async_session() as session:
|
||||||
# Get data entry from database based on file path
|
# Get data entry from database based on file path
|
||||||
data = (
|
data = (
|
||||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||||
).one()
|
).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test local files not created by cognee won't get deleted
|
# Test local files not created by cognee won't get deleted
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert os.path.exists(data.raw_data_location), (
|
assert os.path.exists(
|
||||||
f"Data location doesn't exists: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exists: {data.raw_data_location}"
|
||||||
|
|
||||||
|
|
||||||
async def test_getting_of_documents(dataset_name_1):
|
async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||||
assert len(document_ids) == 1, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
len(document_ids) == 1
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||||
|
|
||||||
# Test getting of documents for search when no dataset is provided
|
# Test getting of documents for search when no dataset is provided
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id)
|
document_ids = await get_document_ids_for_user(user.id)
|
||||||
assert len(document_ids) == 2, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
len(document_ids) == 2
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,9 @@ async def test_deduplication():
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("data")
|
result = await relational_engine.get_all_data_from_table("data")
|
||||||
assert len(result) == 1, "More than one data entity was found."
|
assert len(result) == 1, "More than one data entity was found."
|
||||||
assert result[0]["name"] == "Natural_language_processing_copy", (
|
assert (
|
||||||
"Result name does not match expected value."
|
result[0]["name"] == "Natural_language_processing_copy"
|
||||||
)
|
), "Result name does not match expected value."
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("datasets")
|
result = await relational_engine.get_all_data_from_table("datasets")
|
||||||
assert len(result) == 2, "Unexpected number of datasets found."
|
assert len(result) == 2, "Unexpected number of datasets found."
|
||||||
|
|
@ -61,9 +61,9 @@ async def test_deduplication():
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("data")
|
result = await relational_engine.get_all_data_from_table("data")
|
||||||
assert len(result) == 1, "More than one data entity was found."
|
assert len(result) == 1, "More than one data entity was found."
|
||||||
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
|
assert (
|
||||||
"Content hash is not a part of file name."
|
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
|
||||||
)
|
), "Content hash is not a part of file name."
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
|
||||||
|
|
@ -92,9 +92,9 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
assert not os.path.exists(get_relational_engine().db_path), (
|
assert not os.path.exists(
|
||||||
"SQLite relational database is not empty"
|
get_relational_engine().db_path
|
||||||
)
|
), "SQLite relational database is not empty"
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,13 +103,13 @@ async def main():
|
||||||
node_name=["nonexistent"],
|
node_name=["nonexistent"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
isinstance(context_nonempty, str) and context_nonempty != ""
|
||||||
)
|
), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||||
|
|
||||||
assert context_empty == "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
context_empty == ""
|
||||||
)
|
), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
|
||||||
|
|
@ -107,13 +107,13 @@ async def main():
|
||||||
node_name=["nonexistent"],
|
node_name=["nonexistent"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
isinstance(context_nonempty, str) and context_nonempty != ""
|
||||||
)
|
), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||||
|
|
||||||
assert context_empty == "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
context_empty == ""
|
||||||
)
|
), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
|
||||||
|
|
@ -23,28 +23,28 @@ async def test_local_file_deletion(data_text, file_location):
|
||||||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||||
# Get data entry from database based on hash contents
|
# Get data entry from database based on hash contents
|
||||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test deletion of data along with local files created by cognee
|
# Test deletion of data along with local files created by cognee
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert not os.path.exists(data.raw_data_location), (
|
assert not os.path.exists(
|
||||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location still exists after deletion: {data.raw_data_location}"
|
||||||
|
|
||||||
async with engine.get_async_session() as session:
|
async with engine.get_async_session() as session:
|
||||||
# Get data entry from database based on file path
|
# Get data entry from database based on file path
|
||||||
data = (
|
data = (
|
||||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||||
).one()
|
).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test local files not created by cognee won't get deleted
|
# Test local files not created by cognee won't get deleted
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert os.path.exists(data.raw_data_location), (
|
assert os.path.exists(
|
||||||
f"Data location doesn't exists: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exists: {data.raw_data_location}"
|
||||||
|
|
||||||
|
|
||||||
async def test_getting_of_documents(dataset_name_1):
|
async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
@ -53,16 +53,16 @@ async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||||
assert len(document_ids) == 1, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
len(document_ids) == 1
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||||
|
|
||||||
# Test getting of documents for search when no dataset is provided
|
# Test getting of documents for search when no dataset is provided
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id)
|
document_ids = await get_document_ids_for_user(user.id)
|
||||||
assert len(document_ids) == 2, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
len(document_ids) == 2
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
||||||
|
|
@ -112,9 +112,9 @@ async def relational_db_migration():
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||||
|
|
||||||
assert len(distinct_node_names) == 12, (
|
assert (
|
||||||
f"Expected 12 distinct node references, found {len(distinct_node_names)}"
|
len(distinct_node_names) == 12
|
||||||
)
|
), f"Expected 12 distinct node references, found {len(distinct_node_names)}"
|
||||||
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
|
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
|
||||||
|
|
||||||
expected_edges = {
|
expected_edges = {
|
||||||
|
|
|
||||||
|
|
@ -29,54 +29,54 @@ async def main():
|
||||||
logging.info(edge_type_counts)
|
logging.info(edge_type_counts)
|
||||||
|
|
||||||
# Assert there is exactly one PdfDocument.
|
# Assert there is exactly one PdfDocument.
|
||||||
assert type_counts.get("PdfDocument", 0) == 1, (
|
assert (
|
||||||
f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
|
type_counts.get("PdfDocument", 0) == 1
|
||||||
)
|
), f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
|
||||||
|
|
||||||
# Assert there is exactly one TextDocument.
|
# Assert there is exactly one TextDocument.
|
||||||
assert type_counts.get("TextDocument", 0) == 1, (
|
assert (
|
||||||
f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
|
type_counts.get("TextDocument", 0) == 1
|
||||||
)
|
), f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
|
||||||
|
|
||||||
# Assert there are at least two DocumentChunk nodes.
|
# Assert there are at least two DocumentChunk nodes.
|
||||||
assert type_counts.get("DocumentChunk", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
|
type_counts.get("DocumentChunk", 0) >= 2
|
||||||
)
|
), f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
|
||||||
|
|
||||||
# Assert there is at least two TextSummary.
|
# Assert there is at least two TextSummary.
|
||||||
assert type_counts.get("TextSummary", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
|
type_counts.get("TextSummary", 0) >= 2
|
||||||
)
|
), f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
|
||||||
|
|
||||||
# Assert there is at least one Entity.
|
# Assert there is at least one Entity.
|
||||||
assert type_counts.get("Entity", 0) > 0, (
|
assert (
|
||||||
f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
|
type_counts.get("Entity", 0) > 0
|
||||||
)
|
), f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
|
||||||
|
|
||||||
# Assert there is at least one EntityType.
|
# Assert there is at least one EntityType.
|
||||||
assert type_counts.get("EntityType", 0) > 0, (
|
assert (
|
||||||
f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
|
type_counts.get("EntityType", 0) > 0
|
||||||
)
|
), f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
|
||||||
|
|
||||||
# Assert that there are at least two 'is_part_of' edges.
|
# Assert that there are at least two 'is_part_of' edges.
|
||||||
assert edge_type_counts.get("is_part_of", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
|
edge_type_counts.get("is_part_of", 0) >= 2
|
||||||
)
|
), f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
|
||||||
|
|
||||||
# Assert that there are at least two 'made_from' edges.
|
# Assert that there are at least two 'made_from' edges.
|
||||||
assert edge_type_counts.get("made_from", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
|
edge_type_counts.get("made_from", 0) >= 2
|
||||||
)
|
), f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
|
||||||
|
|
||||||
# Assert that there is at least one 'is_a' edge.
|
# Assert that there is at least one 'is_a' edge.
|
||||||
assert edge_type_counts.get("is_a", 0) >= 1, (
|
assert (
|
||||||
f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
|
edge_type_counts.get("is_a", 0) >= 1
|
||||||
)
|
), f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
|
||||||
|
|
||||||
# Assert that there is at least one 'contains' edge.
|
# Assert that there is at least one 'contains' edge.
|
||||||
assert edge_type_counts.get("contains", 0) >= 1, (
|
assert (
|
||||||
f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
|
edge_type_counts.get("contains", 0) >= 1
|
||||||
)
|
), f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -66,9 +66,9 @@ async def main():
|
||||||
assert isinstance(context, str), f"{name}: Context should be a string"
|
assert isinstance(context, str), f"{name}: Context should be a string"
|
||||||
assert context.strip(), f"{name}: Context should not be empty"
|
assert context.strip(), f"{name}: Context should not be empty"
|
||||||
lower = context.lower()
|
lower = context.lower()
|
||||||
assert "germany" in lower or "netherlands" in lower, (
|
assert (
|
||||||
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
"germany" in lower or "netherlands" in lower
|
||||||
)
|
), f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
||||||
|
|
||||||
triplets_gk = await GraphCompletionRetriever().get_triplets(
|
triplets_gk = await GraphCompletionRetriever().get_triplets(
|
||||||
query="Next to which country is Germany located?"
|
query="Next to which country is Germany located?"
|
||||||
|
|
@ -96,18 +96,18 @@ async def main():
|
||||||
distance = edge.attributes.get("vector_distance")
|
distance = edge.attributes.get("vector_distance")
|
||||||
node1_distance = edge.node1.attributes.get("vector_distance")
|
node1_distance = edge.node1.attributes.get("vector_distance")
|
||||||
node2_distance = edge.node2.attributes.get("vector_distance")
|
node2_distance = edge.node2.attributes.get("vector_distance")
|
||||||
assert isinstance(distance, float), (
|
assert isinstance(
|
||||||
f"{name}: vector_distance should be float, got {type(distance)}"
|
distance, float
|
||||||
)
|
), f"{name}: vector_distance should be float, got {type(distance)}"
|
||||||
assert 0 <= distance <= 1, (
|
assert (
|
||||||
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
0 <= distance <= 1
|
||||||
)
|
), f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
assert 0 <= node1_distance <= 1, (
|
assert (
|
||||||
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
0 <= node1_distance <= 1
|
||||||
)
|
), f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
assert 0 <= node2_distance <= 1, (
|
assert (
|
||||||
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
0 <= node2_distance <= 1
|
||||||
)
|
), f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
|
|
||||||
completion_gk = await cognee.search(
|
completion_gk = await cognee.search(
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
|
@ -137,9 +137,9 @@ async def main():
|
||||||
text = completion[0]
|
text = completion[0]
|
||||||
assert isinstance(text, str), f"{name}: element should be a string"
|
assert isinstance(text, str), f"{name}: element should be a string"
|
||||||
assert text.strip(), f"{name}: string should not be empty"
|
assert text.strip(), f"{name}: string should not be empty"
|
||||||
assert "netherlands" in text.lower(), (
|
assert (
|
||||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
"netherlands" in text.lower()
|
||||||
)
|
), f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,12 @@ async def test_answer_generation():
|
||||||
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
|
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
|
||||||
|
|
||||||
assert len(answers) == len(qa_pairs)
|
assert len(answers) == len(qa_pairs)
|
||||||
assert answers[0]["question"] == qa_pairs[0]["question"], (
|
assert (
|
||||||
"AnswerGeneratorExecutor is passing the question incorrectly"
|
answers[0]["question"] == qa_pairs[0]["question"]
|
||||||
)
|
), "AnswerGeneratorExecutor is passing the question incorrectly"
|
||||||
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
|
assert (
|
||||||
"AnswerGeneratorExecutor is passing the golden answer incorrectly"
|
answers[0]["golden_answer"] == qa_pairs[0]["answer"]
|
||||||
)
|
), "AnswerGeneratorExecutor is passing the golden answer incorrectly"
|
||||||
assert answers[0]["answer"] == "Mocked answer", (
|
assert (
|
||||||
"AnswerGeneratorExecutor is passing the generated answer incorrectly"
|
answers[0]["answer"] == "Mocked answer"
|
||||||
)
|
), "AnswerGeneratorExecutor is passing the generated answer incorrectly"
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,9 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
||||||
|
|
||||||
corpus_list, qa_pairs = result
|
corpus_list, qa_pairs = result
|
||||||
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
|
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
|
||||||
assert isinstance(qa_pairs, list), (
|
assert isinstance(
|
||||||
f"{AdapterClass.__name__} question_answer_pairs is not a list."
|
qa_pairs, list
|
||||||
)
|
), f"{AdapterClass.__name__} question_answer_pairs is not a list."
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
|
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
|
||||||
|
|
@ -71,9 +71,9 @@ def test_adapter_returns_some_content(AdapterClass):
|
||||||
# We don't know how large the dataset is, but we expect at least 1 item
|
# We don't know how large the dataset is, but we expect at least 1 item
|
||||||
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
|
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
|
||||||
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
|
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
|
||||||
assert len(qa_pairs) <= limit, (
|
assert (
|
||||||
f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
|
len(qa_pairs) <= limit
|
||||||
)
|
), f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
|
||||||
|
|
||||||
for item in qa_pairs:
|
for item in qa_pairs:
|
||||||
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."
|
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ def test_corpus_builder_load_corpus(benchmark):
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||||
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
||||||
assert len(questions) <= 2, (
|
assert (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
len(questions) <= 2
|
||||||
)
|
), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -24,6 +24,6 @@ async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
||||||
limit = 2
|
limit = 2
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
questions = await corpus_builder.build_corpus(limit=limit)
|
questions = await corpus_builder.build_corpus(limit=limit)
|
||||||
assert len(questions) <= 2, (
|
assert (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
len(questions) <= 2
|
||||||
)
|
), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
|
|
|
||||||
|
|
@ -52,14 +52,14 @@ def test_metrics(metrics, actual, expected, expected_exact_score, expected_f1_ra
|
||||||
test_case = MockTestCase(actual, expected)
|
test_case = MockTestCase(actual, expected)
|
||||||
|
|
||||||
exact_match_score = metrics["exact_match"].measure(test_case)
|
exact_match_score = metrics["exact_match"].measure(test_case)
|
||||||
assert exact_match_score == expected_exact_score, (
|
assert (
|
||||||
f"Exact match failed for '{actual}' vs '{expected}'"
|
exact_match_score == expected_exact_score
|
||||||
)
|
), f"Exact match failed for '{actual}' vs '{expected}'"
|
||||||
|
|
||||||
f1_score = metrics["f1"].measure(test_case)
|
f1_score = metrics["f1"].measure(test_case)
|
||||||
assert expected_f1_range[0] <= f1_score <= expected_f1_range[1], (
|
assert (
|
||||||
f"F1 score failed for '{actual}' vs '{expected}'"
|
expected_f1_range[0] <= f1_score <= expected_f1_range[1]
|
||||||
)
|
), f"F1 score failed for '{actual}' vs '{expected}'"
|
||||||
|
|
||||||
|
|
||||||
class TestBootstrapCI(unittest.TestCase):
|
class TestBootstrapCI(unittest.TestCase):
|
||||||
|
|
|
||||||
|
|
@ -157,15 +157,15 @@ def test_rate_limit_60_per_minute():
|
||||||
if len(failures) > 0:
|
if len(failures) > 0:
|
||||||
first_failure_idx = int(failures[0].split()[1])
|
first_failure_idx = int(failures[0].split()[1])
|
||||||
print(f"First failure occurred at request index: {first_failure_idx}")
|
print(f"First failure occurred at request index: {first_failure_idx}")
|
||||||
assert 58 <= first_failure_idx <= 62, (
|
assert (
|
||||||
f"Expected first failure around request #60, got #{first_failure_idx}"
|
58 <= first_failure_idx <= 62
|
||||||
)
|
), f"Expected first failure around request #60, got #{first_failure_idx}"
|
||||||
|
|
||||||
# Calculate requests per minute
|
# Calculate requests per minute
|
||||||
rate_per_minute = len(successes)
|
rate_per_minute = len(successes)
|
||||||
print(f"Rate: {rate_per_minute} requests per minute")
|
print(f"Rate: {rate_per_minute} requests per minute")
|
||||||
|
|
||||||
# Verify the rate is close to our target of 60 requests per minute
|
# Verify the rate is close to our target of 60 requests per minute
|
||||||
assert 58 <= rate_per_minute <= 62, (
|
assert (
|
||||||
f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
|
58 <= rate_per_minute <= 62
|
||||||
)
|
), f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
|
||||||
|
|
|
||||||
|
|
@ -110,9 +110,9 @@ def test_sync_retry():
|
||||||
print(f"Number of attempts: {test_function_sync.counter}")
|
print(f"Number of attempts: {test_function_sync.counter}")
|
||||||
|
|
||||||
# The function should succeed on the 3rd attempt (after 2 failures)
|
# The function should succeed on the 3rd attempt (after 2 failures)
|
||||||
assert test_function_sync.counter == 3, (
|
assert (
|
||||||
f"Expected 3 attempts, got {test_function_sync.counter}"
|
test_function_sync.counter == 3
|
||||||
)
|
), f"Expected 3 attempts, got {test_function_sync.counter}"
|
||||||
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
||||||
|
|
||||||
print("✅ PASS: Synchronous retry mechanism is working correctly")
|
print("✅ PASS: Synchronous retry mechanism is working correctly")
|
||||||
|
|
@ -143,9 +143,9 @@ async def test_async_retry():
|
||||||
print(f"Number of attempts: {test_function_async.counter}")
|
print(f"Number of attempts: {test_function_async.counter}")
|
||||||
|
|
||||||
# The function should succeed on the 3rd attempt (after 2 failures)
|
# The function should succeed on the 3rd attempt (after 2 failures)
|
||||||
assert test_function_async.counter == 3, (
|
assert (
|
||||||
f"Expected 3 attempts, got {test_function_async.counter}"
|
test_function_async.counter == 3
|
||||||
)
|
), f"Expected 3 attempts, got {test_function_async.counter}"
|
||||||
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
||||||
|
|
||||||
print("✅ PASS: Asynchronous retry mechanism is working correctly")
|
print("✅ PASS: Asynchronous retry mechanism is working correctly")
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Canva?")
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_extension_context_complex(self):
|
async def test_graph_completion_extension_context_complex(self):
|
||||||
|
|
@ -136,9 +136,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||||
|
|
@ -167,9 +167,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -55,9 +55,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Canva?")
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_cot_context_complex(self):
|
async def test_graph_completion_cot_context_complex(self):
|
||||||
|
|
@ -134,9 +134,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||||
|
|
@ -165,9 +165,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,9 @@ max_chunk_size_vals = [512, 1024, 4096]
|
||||||
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
|
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
|
||||||
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
|
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
|
||||||
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
|
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -54,9 +54,9 @@ def test_paragraph_chunk_length(input_text, max_chunk_size, batch_paragraphs):
|
||||||
)
|
)
|
||||||
|
|
||||||
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
|
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
|
||||||
assert np.all(chunk_lengths <= max_chunk_size), (
|
assert np.all(
|
||||||
f"{max_chunk_size = }: {larger_chunks} are too large"
|
chunk_lengths <= max_chunk_size
|
||||||
)
|
), f"{max_chunk_size = }: {larger_chunks} are too large"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -76,6 +76,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, max_chunk_size, batch_pa
|
||||||
batch_paragraphs=batch_paragraphs,
|
batch_paragraphs=batch_paragraphs,
|
||||||
)
|
)
|
||||||
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
||||||
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
|
assert np.all(
|
||||||
f"{chunk_indices = } are not monotonically increasing"
|
chunk_indices == np.arange(len(chunk_indices))
|
||||||
)
|
), f"{chunk_indices = } are not monotonically increasing"
|
||||||
|
|
|
||||||
|
|
@ -71,9 +71,9 @@ def run_chunking_test(test_text, expected_chunks, mock_engine):
|
||||||
|
|
||||||
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
|
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
|
||||||
for key in ["text", "chunk_size", "cut_type"]:
|
for key in ["text", "chunk_size", "cut_type"]:
|
||||||
assert chunk[key] == expected_chunks_item[key], (
|
assert (
|
||||||
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
|
chunk[key] == expected_chunks_item[key]
|
||||||
)
|
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
|
||||||
|
|
||||||
|
|
||||||
def test_chunking_whole_text():
|
def test_chunking_whole_text():
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ maximum_length_vals = [None, 16, 64]
|
||||||
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
|
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
|
||||||
chunks = chunk_by_sentence(input_text, maximum_length)
|
chunks = chunk_by_sentence(input_text, maximum_length)
|
||||||
reconstructed_text = "".join([chunk[1] for chunk in chunks])
|
reconstructed_text = "".join([chunk[1] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -40,9 +40,9 @@ def test_paragraph_chunk_length(input_text, maximum_length):
|
||||||
)
|
)
|
||||||
|
|
||||||
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
|
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
|
||||||
assert np.all(chunk_lengths <= maximum_length), (
|
assert np.all(
|
||||||
f"{maximum_length = }: {larger_chunks} are too large"
|
chunk_lengths <= maximum_length
|
||||||
)
|
), f"{maximum_length = }: {larger_chunks} are too large"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS, INPUT_TE
|
||||||
def test_chunk_by_word_isomorphism(input_text):
|
def test_chunk_by_word_isomorphism(input_text):
|
||||||
chunks = chunk_by_word(input_text)
|
chunks = chunk_by_word(input_text)
|
||||||
reconstructed_text = "".join([chunk[0] for chunk in chunks])
|
reconstructed_text = "".join([chunk[0] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,14 @@ async def demonstrate_incremental_loading():
|
||||||
Demonstrate incremental file loading by creating a file, modifying it,
|
Demonstrate incremental file loading by creating a file, modifying it,
|
||||||
and showing how only changed blocks are detected.
|
and showing how only changed blocks are detected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("🚀 Cognee Incremental File Loading Demo")
|
print("🚀 Cognee Incremental File Loading Demo")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# Initialize the incremental loader
|
# Initialize the incremental loader
|
||||||
incremental_loader = IncrementalLoader(block_size=512) # 512 byte blocks for demo
|
incremental_loader = IncrementalLoader(block_size=512) # 512 byte blocks for demo
|
||||||
block_service = BlockHashService(block_size=512)
|
block_service = BlockHashService(block_size=512)
|
||||||
|
|
||||||
# Create initial file content
|
# Create initial file content
|
||||||
initial_content = b"""
|
initial_content = b"""
|
||||||
This is the initial content of our test file.
|
This is the initial content of our test file.
|
||||||
|
|
@ -40,7 +40,7 @@ Block 5: Excepteur sint occaecat cupidatat non proident, sunt in culpa.
|
||||||
|
|
||||||
This is the end of the initial content.
|
This is the end of the initial content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create modified content (change Block 2 and add Block 6)
|
# Create modified content (change Block 2 and add Block 6)
|
||||||
modified_content = b"""
|
modified_content = b"""
|
||||||
This is the initial content of our test file.
|
This is the initial content of our test file.
|
||||||
|
|
@ -56,64 +56,70 @@ Block 6: NEW BLOCK - This is additional content that was added.
|
||||||
|
|
||||||
This is the end of the modified content.
|
This is the end of the modified content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("1. Creating signatures for initial and modified versions...")
|
print("1. Creating signatures for initial and modified versions...")
|
||||||
|
|
||||||
# Generate signatures
|
# Generate signatures
|
||||||
initial_file = BytesIO(initial_content)
|
initial_file = BytesIO(initial_content)
|
||||||
modified_file = BytesIO(modified_content)
|
modified_file = BytesIO(modified_content)
|
||||||
|
|
||||||
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
||||||
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
||||||
|
|
||||||
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks")
|
print(
|
||||||
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks")
|
f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
|
||||||
# Compare signatures to find changes
|
# Compare signatures to find changes
|
||||||
print("\n2. Comparing signatures to detect changes...")
|
print("\n2. Comparing signatures to detect changes...")
|
||||||
|
|
||||||
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
|
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
|
||||||
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
|
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
|
||||||
|
|
||||||
print(f" Changed blocks: {changed_blocks}")
|
print(f" Changed blocks: {changed_blocks}")
|
||||||
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
||||||
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}")
|
print(
|
||||||
|
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Generate delta
|
# Generate delta
|
||||||
print("\n3. Generating delta for changed content...")
|
print("\n3. Generating delta for changed content...")
|
||||||
|
|
||||||
initial_file.seek(0)
|
initial_file.seek(0)
|
||||||
modified_file.seek(0)
|
modified_file.seek(0)
|
||||||
|
|
||||||
delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
|
delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
|
||||||
|
|
||||||
print(f" Delta size: {len(delta.delta_data)} bytes")
|
print(f" Delta size: {len(delta.delta_data)} bytes")
|
||||||
print(f" Changed blocks in delta: {delta.changed_blocks}")
|
print(f" Changed blocks in delta: {delta.changed_blocks}")
|
||||||
|
|
||||||
# Demonstrate reconstruction
|
# Demonstrate reconstruction
|
||||||
print("\n4. Reconstructing file from delta...")
|
print("\n4. Reconstructing file from delta...")
|
||||||
|
|
||||||
initial_file.seek(0)
|
initial_file.seek(0)
|
||||||
reconstructed = block_service.apply_delta(initial_file, delta)
|
reconstructed = block_service.apply_delta(initial_file, delta)
|
||||||
reconstructed_content = reconstructed.read()
|
reconstructed_content = reconstructed.read()
|
||||||
|
|
||||||
print(f" Reconstruction successful: {reconstructed_content == modified_content}")
|
print(f" Reconstruction successful: {reconstructed_content == modified_content}")
|
||||||
print(f" Reconstructed size: {len(reconstructed_content)} bytes")
|
print(f" Reconstructed size: {len(reconstructed_content)} bytes")
|
||||||
|
|
||||||
# Show block details
|
# Show block details
|
||||||
print("\n5. Block-by-block analysis:")
|
print("\n5. Block-by-block analysis:")
|
||||||
print(" Block | Status | Strong Hash (first 8 chars)")
|
print(" Block | Status | Strong Hash (first 8 chars)")
|
||||||
print(" ------|----------|---------------------------")
|
print(" ------|----------|---------------------------")
|
||||||
|
|
||||||
old_blocks = {b.block_index: b for b in initial_signature.blocks}
|
old_blocks = {b.block_index: b for b in initial_signature.blocks}
|
||||||
new_blocks = {b.block_index: b for b in modified_signature.blocks}
|
new_blocks = {b.block_index: b for b in modified_signature.blocks}
|
||||||
|
|
||||||
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
|
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
|
||||||
|
|
||||||
for idx in all_indices:
|
for idx in all_indices:
|
||||||
old_block = old_blocks.get(idx)
|
old_block = old_blocks.get(idx)
|
||||||
new_block = new_blocks.get(idx)
|
new_block = new_blocks.get(idx)
|
||||||
|
|
||||||
if old_block is None:
|
if old_block is None:
|
||||||
status = "ADDED"
|
status = "ADDED"
|
||||||
hash_display = new_block.strong_hash[:8] if new_block else ""
|
hash_display = new_block.strong_hash[:8] if new_block else ""
|
||||||
|
|
@ -126,9 +132,9 @@ This is the end of the modified content.
|
||||||
else:
|
else:
|
||||||
status = "MODIFIED"
|
status = "MODIFIED"
|
||||||
hash_display = f"{old_block.strong_hash[:8]}→{new_block.strong_hash[:8]}"
|
hash_display = f"{old_block.strong_hash[:8]}→{new_block.strong_hash[:8]}"
|
||||||
|
|
||||||
print(f" {idx:5d} | {status:8s} | {hash_display}")
|
print(f" {idx:5d} | {status:8s} | {hash_display}")
|
||||||
|
|
||||||
print("\n✅ Incremental loading demo completed!")
|
print("\n✅ Incremental loading demo completed!")
|
||||||
print("\nThis demonstrates how Cognee can efficiently process only the changed")
|
print("\nThis demonstrates how Cognee can efficiently process only the changed")
|
||||||
print("parts of files, significantly reducing processing time for large files")
|
print("parts of files, significantly reducing processing time for large files")
|
||||||
|
|
@ -139,35 +145,35 @@ async def demonstrate_with_cognee():
|
||||||
"""
|
"""
|
||||||
Demonstrate integration with Cognee's add functionality
|
Demonstrate integration with Cognee's add functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
print("\n" + "=" * 50)
|
||||||
print("🔧 Integration with Cognee Add Functionality")
|
print("🔧 Integration with Cognee Add Functionality")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# Create a temporary file
|
# Create a temporary file
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
|
||||||
f.write("Initial content for Cognee processing.")
|
f.write("Initial content for Cognee processing.")
|
||||||
temp_file_path = f.name
|
temp_file_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f"1. Adding initial file: {temp_file_path}")
|
print(f"1. Adding initial file: {temp_file_path}")
|
||||||
|
|
||||||
# Add file to Cognee
|
# Add file to Cognee
|
||||||
await cognee.add(temp_file_path)
|
await cognee.add(temp_file_path)
|
||||||
|
|
||||||
print(" ✅ File added successfully")
|
print(" ✅ File added successfully")
|
||||||
|
|
||||||
# Modify the file
|
# Modify the file
|
||||||
with open(temp_file_path, 'w') as f:
|
with open(temp_file_path, "w") as f:
|
||||||
f.write("Modified content for Cognee processing with additional text.")
|
f.write("Modified content for Cognee processing with additional text.")
|
||||||
|
|
||||||
print("2. Adding modified version of the same file...")
|
print("2. Adding modified version of the same file...")
|
||||||
|
|
||||||
# Add modified file - this should trigger incremental processing
|
# Add modified file - this should trigger incremental processing
|
||||||
await cognee.add(temp_file_path)
|
await cognee.add(temp_file_path)
|
||||||
|
|
||||||
print(" ✅ Modified file processed with incremental loading")
|
print(" ✅ Modified file processed with incremental loading")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
if os.path.exists(temp_file_path):
|
if os.path.exists(temp_file_path):
|
||||||
|
|
@ -176,11 +182,11 @@ async def demonstrate_with_cognee():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
print("Starting Cognee Incremental Loading Demo...")
|
print("Starting Cognee Incremental Loading Demo...")
|
||||||
|
|
||||||
# Run the demonstration
|
# Run the demonstration
|
||||||
asyncio.run(demonstrate_incremental_loading())
|
asyncio.run(demonstrate_incremental_loading())
|
||||||
|
|
||||||
# Uncomment the line below to test with actual Cognee integration
|
# Uncomment the line below to test with actual Cognee integration
|
||||||
# asyncio.run(demonstrate_with_cognee())
|
# asyncio.run(demonstrate_with_cognee())
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ Simple test for incremental loading functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(0, '.')
|
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from cognee.modules.ingestion.incremental import BlockHashService
|
from cognee.modules.ingestion.incremental import BlockHashService
|
||||||
|
|
@ -14,13 +15,13 @@ def test_incremental_loading():
|
||||||
"""
|
"""
|
||||||
Simple test of the incremental loading functionality
|
Simple test of the incremental loading functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("🚀 Cognee Incremental File Loading Test")
|
print("🚀 Cognee Incremental File Loading Test")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# Initialize the block service
|
# Initialize the block service
|
||||||
block_service = BlockHashService(block_size=64) # Small blocks for demo
|
block_service = BlockHashService(block_size=64) # Small blocks for demo
|
||||||
|
|
||||||
# Create initial file content
|
# Create initial file content
|
||||||
initial_content = b"""This is the initial content.
|
initial_content = b"""This is the initial content.
|
||||||
Line 1: Lorem ipsum dolor sit amet
|
Line 1: Lorem ipsum dolor sit amet
|
||||||
|
|
@ -28,7 +29,7 @@ Line 2: Consectetur adipiscing elit
|
||||||
Line 3: Sed do eiusmod tempor
|
Line 3: Sed do eiusmod tempor
|
||||||
Line 4: Incididunt ut labore et dolore
|
Line 4: Incididunt ut labore et dolore
|
||||||
Line 5: End of initial content"""
|
Line 5: End of initial content"""
|
||||||
|
|
||||||
# Create modified content (change Line 2 and add Line 6)
|
# Create modified content (change Line 2 and add Line 6)
|
||||||
modified_content = b"""This is the initial content.
|
modified_content = b"""This is the initial content.
|
||||||
Line 1: Lorem ipsum dolor sit amet
|
Line 1: Lorem ipsum dolor sit amet
|
||||||
|
|
@ -37,64 +38,70 @@ Line 3: Sed do eiusmod tempor
|
||||||
Line 4: Incididunt ut labore et dolore
|
Line 4: Incididunt ut labore et dolore
|
||||||
Line 5: End of initial content
|
Line 5: End of initial content
|
||||||
Line 6: NEW - This is additional content"""
|
Line 6: NEW - This is additional content"""
|
||||||
|
|
||||||
print("1. Creating signatures for initial and modified versions...")
|
print("1. Creating signatures for initial and modified versions...")
|
||||||
|
|
||||||
# Generate signatures
|
# Generate signatures
|
||||||
initial_file = BytesIO(initial_content)
|
initial_file = BytesIO(initial_content)
|
||||||
modified_file = BytesIO(modified_content)
|
modified_file = BytesIO(modified_content)
|
||||||
|
|
||||||
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
||||||
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
||||||
|
|
||||||
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks")
|
print(
|
||||||
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks")
|
f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
|
||||||
# Compare signatures to find changes
|
# Compare signatures to find changes
|
||||||
print("\n2. Comparing signatures to detect changes...")
|
print("\n2. Comparing signatures to detect changes...")
|
||||||
|
|
||||||
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
|
changed_blocks = block_service.compare_signatures(initial_signature, modified_signature)
|
||||||
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
|
change_stats = block_service.calculate_block_changes(initial_signature, modified_signature)
|
||||||
|
|
||||||
print(f" Changed blocks: {changed_blocks}")
|
print(f" Changed blocks: {changed_blocks}")
|
||||||
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
||||||
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}")
|
print(
|
||||||
|
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Generate delta
|
# Generate delta
|
||||||
print("\n3. Generating delta for changed content...")
|
print("\n3. Generating delta for changed content...")
|
||||||
|
|
||||||
initial_file.seek(0)
|
initial_file.seek(0)
|
||||||
modified_file.seek(0)
|
modified_file.seek(0)
|
||||||
|
|
||||||
delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
|
delta = block_service.generate_delta(initial_file, modified_file, initial_signature)
|
||||||
|
|
||||||
print(f" Delta size: {len(delta.delta_data)} bytes")
|
print(f" Delta size: {len(delta.delta_data)} bytes")
|
||||||
print(f" Changed blocks in delta: {delta.changed_blocks}")
|
print(f" Changed blocks in delta: {delta.changed_blocks}")
|
||||||
|
|
||||||
# Demonstrate reconstruction
|
# Demonstrate reconstruction
|
||||||
print("\n4. Reconstructing file from delta...")
|
print("\n4. Reconstructing file from delta...")
|
||||||
|
|
||||||
initial_file.seek(0)
|
initial_file.seek(0)
|
||||||
reconstructed = block_service.apply_delta(initial_file, delta)
|
reconstructed = block_service.apply_delta(initial_file, delta)
|
||||||
reconstructed_content = reconstructed.read()
|
reconstructed_content = reconstructed.read()
|
||||||
|
|
||||||
print(f" Reconstruction successful: {reconstructed_content == modified_content}")
|
print(f" Reconstruction successful: {reconstructed_content == modified_content}")
|
||||||
print(f" Reconstructed size: {len(reconstructed_content)} bytes")
|
print(f" Reconstructed size: {len(reconstructed_content)} bytes")
|
||||||
|
|
||||||
# Show block details
|
# Show block details
|
||||||
print("\n5. Block-by-block analysis:")
|
print("\n5. Block-by-block analysis:")
|
||||||
print(" Block | Status | Strong Hash (first 8 chars)")
|
print(" Block | Status | Strong Hash (first 8 chars)")
|
||||||
print(" ------|----------|---------------------------")
|
print(" ------|----------|---------------------------")
|
||||||
|
|
||||||
old_blocks = {b.block_index: b for b in initial_signature.blocks}
|
old_blocks = {b.block_index: b for b in initial_signature.blocks}
|
||||||
new_blocks = {b.block_index: b for b in modified_signature.blocks}
|
new_blocks = {b.block_index: b for b in modified_signature.blocks}
|
||||||
|
|
||||||
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
|
all_indices = sorted(set(old_blocks.keys()) | set(new_blocks.keys()))
|
||||||
|
|
||||||
for idx in all_indices:
|
for idx in all_indices:
|
||||||
old_block = old_blocks.get(idx)
|
old_block = old_blocks.get(idx)
|
||||||
new_block = new_blocks.get(idx)
|
new_block = new_blocks.get(idx)
|
||||||
|
|
||||||
if old_block is None:
|
if old_block is None:
|
||||||
status = "ADDED"
|
status = "ADDED"
|
||||||
hash_display = new_block.strong_hash[:8] if new_block else ""
|
hash_display = new_block.strong_hash[:8] if new_block else ""
|
||||||
|
|
@ -107,14 +114,14 @@ Line 6: NEW - This is additional content"""
|
||||||
else:
|
else:
|
||||||
status = "MODIFIED"
|
status = "MODIFIED"
|
||||||
hash_display = f"{old_block.strong_hash[:8]}→{new_block.strong_hash[:8]}"
|
hash_display = f"{old_block.strong_hash[:8]}→{new_block.strong_hash[:8]}"
|
||||||
|
|
||||||
print(f" {idx:5d} | {status:8s} | {hash_display}")
|
print(f" {idx:5d} | {status:8s} | {hash_display}")
|
||||||
|
|
||||||
print("\n✅ Incremental loading test completed!")
|
print("\n✅ Incremental loading test completed!")
|
||||||
print("\nThis demonstrates how Cognee can efficiently process only the changed")
|
print("\nThis demonstrates how Cognee can efficiently process only the changed")
|
||||||
print("parts of files, significantly reducing processing time for large files")
|
print("parts of files, significantly reducing processing time for large files")
|
||||||
print("with small modifications.")
|
print("with small modifications.")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -124,4 +131,4 @@ if __name__ == "__main__":
|
||||||
print("\n🎉 Test passed successfully!")
|
print("\n🎉 Test passed successfully!")
|
||||||
else:
|
else:
|
||||||
print("\n❌ Test failed!")
|
print("\n❌ Test failed!")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
|
||||||
|
|
@ -9,96 +9,96 @@ from cognee.modules.ingestion.incremental import BlockHashService, IncrementalLo
|
||||||
|
|
||||||
class TestBlockHashService:
|
class TestBlockHashService:
|
||||||
"""Test the core block hashing service"""
|
"""Test the core block hashing service"""
|
||||||
|
|
||||||
def test_signature_generation(self):
|
def test_signature_generation(self):
|
||||||
"""Test basic signature generation"""
|
"""Test basic signature generation"""
|
||||||
service = BlockHashService(block_size=10)
|
service = BlockHashService(block_size=10)
|
||||||
|
|
||||||
content = b"Hello, this is a test file for block hashing!"
|
content = b"Hello, this is a test file for block hashing!"
|
||||||
file_obj = BytesIO(content)
|
file_obj = BytesIO(content)
|
||||||
|
|
||||||
signature = service.generate_signature(file_obj, "test.txt")
|
signature = service.generate_signature(file_obj, "test.txt")
|
||||||
|
|
||||||
assert signature.file_path == "test.txt"
|
assert signature.file_path == "test.txt"
|
||||||
assert signature.file_size == len(content)
|
assert signature.file_size == len(content)
|
||||||
assert signature.block_size == 10
|
assert signature.block_size == 10
|
||||||
assert len(signature.blocks) > 0
|
assert len(signature.blocks) > 0
|
||||||
assert signature.signature_data is not None
|
assert signature.signature_data is not None
|
||||||
|
|
||||||
def test_change_detection(self):
|
def test_change_detection(self):
|
||||||
"""Test detection of changes between file versions"""
|
"""Test detection of changes between file versions"""
|
||||||
service = BlockHashService(block_size=10)
|
service = BlockHashService(block_size=10)
|
||||||
|
|
||||||
# Original content
|
# Original content
|
||||||
original_content = b"Hello, world! This is the original content."
|
original_content = b"Hello, world! This is the original content."
|
||||||
original_file = BytesIO(original_content)
|
original_file = BytesIO(original_content)
|
||||||
original_sig = service.generate_signature(original_file)
|
original_sig = service.generate_signature(original_file)
|
||||||
|
|
||||||
# Modified content (change in middle)
|
# Modified content (change in middle)
|
||||||
modified_content = b"Hello, world! This is the MODIFIED content."
|
modified_content = b"Hello, world! This is the MODIFIED content."
|
||||||
modified_file = BytesIO(modified_content)
|
modified_file = BytesIO(modified_content)
|
||||||
modified_sig = service.generate_signature(modified_file)
|
modified_sig = service.generate_signature(modified_file)
|
||||||
|
|
||||||
# Check for changes
|
# Check for changes
|
||||||
changed_blocks = service.compare_signatures(original_sig, modified_sig)
|
changed_blocks = service.compare_signatures(original_sig, modified_sig)
|
||||||
|
|
||||||
assert len(changed_blocks) > 0 # Should detect changes
|
assert len(changed_blocks) > 0 # Should detect changes
|
||||||
assert len(changed_blocks) < len(original_sig.blocks) # Not all blocks changed
|
assert len(changed_blocks) < len(original_sig.blocks) # Not all blocks changed
|
||||||
|
|
||||||
def test_no_changes(self):
|
def test_no_changes(self):
|
||||||
"""Test that identical files show no changes"""
|
"""Test that identical files show no changes"""
|
||||||
service = BlockHashService(block_size=10)
|
service = BlockHashService(block_size=10)
|
||||||
|
|
||||||
content = b"This content will not change at all!"
|
content = b"This content will not change at all!"
|
||||||
|
|
||||||
file1 = BytesIO(content)
|
file1 = BytesIO(content)
|
||||||
file2 = BytesIO(content)
|
file2 = BytesIO(content)
|
||||||
|
|
||||||
sig1 = service.generate_signature(file1)
|
sig1 = service.generate_signature(file1)
|
||||||
sig2 = service.generate_signature(file2)
|
sig2 = service.generate_signature(file2)
|
||||||
|
|
||||||
changed_blocks = service.compare_signatures(sig1, sig2)
|
changed_blocks = service.compare_signatures(sig1, sig2)
|
||||||
|
|
||||||
assert len(changed_blocks) == 0
|
assert len(changed_blocks) == 0
|
||||||
|
|
||||||
def test_delta_generation(self):
|
def test_delta_generation(self):
|
||||||
"""Test delta generation and application"""
|
"""Test delta generation and application"""
|
||||||
service = BlockHashService(block_size=8)
|
service = BlockHashService(block_size=8)
|
||||||
|
|
||||||
original_content = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
original_content = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
modified_content = b"ABCDEFGHXXXXXXXXXXXXXXWXYZ" # Change middle part
|
modified_content = b"ABCDEFGHXXXXXXXXXXXXXXWXYZ" # Change middle part
|
||||||
|
|
||||||
original_file = BytesIO(original_content)
|
original_file = BytesIO(original_content)
|
||||||
modified_file = BytesIO(modified_content)
|
modified_file = BytesIO(modified_content)
|
||||||
|
|
||||||
# Generate delta
|
# Generate delta
|
||||||
delta = service.generate_delta(original_file, modified_file)
|
delta = service.generate_delta(original_file, modified_file)
|
||||||
|
|
||||||
assert len(delta.changed_blocks) > 0
|
assert len(delta.changed_blocks) > 0
|
||||||
assert delta.delta_data is not None
|
assert delta.delta_data is not None
|
||||||
|
|
||||||
# Apply delta
|
# Apply delta
|
||||||
original_file.seek(0)
|
original_file.seek(0)
|
||||||
reconstructed = service.apply_delta(original_file, delta)
|
reconstructed = service.apply_delta(original_file, delta)
|
||||||
reconstructed_content = reconstructed.read()
|
reconstructed_content = reconstructed.read()
|
||||||
|
|
||||||
assert reconstructed_content == modified_content
|
assert reconstructed_content == modified_content
|
||||||
|
|
||||||
def test_block_statistics(self):
|
def test_block_statistics(self):
|
||||||
"""Test calculation of block change statistics"""
|
"""Test calculation of block change statistics"""
|
||||||
service = BlockHashService(block_size=5)
|
service = BlockHashService(block_size=5)
|
||||||
|
|
||||||
old_content = b"ABCDEFGHIJ" # 2 blocks
|
old_content = b"ABCDEFGHIJ" # 2 blocks
|
||||||
new_content = b"ABCDEFXXXX" # 2 blocks, second one changed
|
new_content = b"ABCDEFXXXX" # 2 blocks, second one changed
|
||||||
|
|
||||||
old_file = BytesIO(old_content)
|
old_file = BytesIO(old_content)
|
||||||
new_file = BytesIO(new_content)
|
new_file = BytesIO(new_content)
|
||||||
|
|
||||||
old_sig = service.generate_signature(old_file)
|
old_sig = service.generate_signature(old_file)
|
||||||
new_sig = service.generate_signature(new_file)
|
new_sig = service.generate_signature(new_file)
|
||||||
|
|
||||||
stats = service.calculate_block_changes(old_sig, new_sig)
|
stats = service.calculate_block_changes(old_sig, new_sig)
|
||||||
|
|
||||||
assert stats["total_old_blocks"] == 2
|
assert stats["total_old_blocks"] == 2
|
||||||
assert stats["total_new_blocks"] == 2
|
assert stats["total_new_blocks"] == 2
|
||||||
assert stats["changed_blocks"] == 1 # Only second block changed
|
assert stats["changed_blocks"] == 1 # Only second block changed
|
||||||
|
|
@ -107,36 +107,36 @@ class TestBlockHashService:
|
||||||
|
|
||||||
class TestIncrementalLoader:
|
class TestIncrementalLoader:
|
||||||
"""Test the incremental loader integration"""
|
"""Test the incremental loader integration"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_should_process_new_file(self):
|
async def test_should_process_new_file(self):
|
||||||
"""Test processing decision for new files"""
|
"""Test processing decision for new files"""
|
||||||
loader = IncrementalLoader()
|
loader = IncrementalLoader()
|
||||||
|
|
||||||
content = b"This is a new file that hasn't been seen before."
|
content = b"This is a new file that hasn't been seen before."
|
||||||
file_obj = BytesIO(content)
|
file_obj = BytesIO(content)
|
||||||
|
|
||||||
# For a new file (no existing signature), should process
|
# For a new file (no existing signature), should process
|
||||||
# Note: This test would need a mock database setup in real implementation
|
# Note: This test would need a mock database setup in real implementation
|
||||||
# For now, we test the logic without database interaction
|
# For now, we test the logic without database interaction
|
||||||
pass # Placeholder for database-dependent test
|
pass # Placeholder for database-dependent test
|
||||||
|
|
||||||
def test_block_data_extraction(self):
|
def test_block_data_extraction(self):
|
||||||
"""Test extraction of changed block data"""
|
"""Test extraction of changed block data"""
|
||||||
loader = IncrementalLoader(block_size=10)
|
loader = IncrementalLoader(block_size=10)
|
||||||
|
|
||||||
content = b"Block1____Block2____Block3____"
|
content = b"Block1____Block2____Block3____"
|
||||||
file_obj = BytesIO(content)
|
file_obj = BytesIO(content)
|
||||||
|
|
||||||
# Create mock change info
|
# Create mock change info
|
||||||
from cognee.modules.ingestion.incremental.block_hash_service import BlockInfo, FileSignature
|
from cognee.modules.ingestion.incremental.block_hash_service import BlockInfo, FileSignature
|
||||||
|
|
||||||
blocks = [
|
blocks = [
|
||||||
BlockInfo(0, 12345, "hash1", 10, 0),
|
BlockInfo(0, 12345, "hash1", 10, 0),
|
||||||
BlockInfo(1, 23456, "hash2", 10, 10),
|
BlockInfo(1, 23456, "hash2", 10, 10),
|
||||||
BlockInfo(2, 34567, "hash3", 10, 20),
|
BlockInfo(2, 34567, "hash3", 10, 20),
|
||||||
]
|
]
|
||||||
|
|
||||||
signature = FileSignature(
|
signature = FileSignature(
|
||||||
file_path="test",
|
file_path="test",
|
||||||
file_size=30,
|
file_size=30,
|
||||||
|
|
@ -144,19 +144,19 @@ class TestIncrementalLoader:
|
||||||
block_size=10,
|
block_size=10,
|
||||||
strong_len=8,
|
strong_len=8,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
signature_data=b"signature"
|
signature_data=b"signature",
|
||||||
)
|
)
|
||||||
|
|
||||||
change_info = {
|
change_info = {
|
||||||
"type": "incremental_changes",
|
"type": "incremental_changes",
|
||||||
"changed_blocks": [1], # Only middle block changed
|
"changed_blocks": [1], # Only middle block changed
|
||||||
"new_signature": signature
|
"new_signature": signature,
|
||||||
}
|
}
|
||||||
|
|
||||||
# This would normally be called after should_process_file
|
# This would normally be called after should_process_file
|
||||||
# Testing the block extraction logic
|
# Testing the block extraction logic
|
||||||
pass # Placeholder for full integration test
|
pass # Placeholder for full integration test
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue