reformat applied
This commit is contained in:
parent
1ebeeac61d
commit
07f2afa69d
39 changed files with 577 additions and 531 deletions
|
|
@ -5,43 +5,49 @@ Revises: 1d0bb7fede17
|
||||||
Create Date: 2025-01-27 12:00:00.000000
|
Create Date: 2025-01-27 12:00:00.000000
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = 'incremental_file_signatures'
|
revision = "incremental_file_signatures"
|
||||||
down_revision = '1d0bb7fede17'
|
down_revision = "1d0bb7fede17"
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('file_signatures',
|
op.create_table(
|
||||||
sa.Column('id', sa.UUID(), nullable=False, default=uuid4),
|
"file_signatures",
|
||||||
sa.Column('data_id', sa.UUID(), nullable=True),
|
sa.Column("id", sa.UUID(), nullable=False, default=uuid4),
|
||||||
sa.Column('file_path', sa.String(), nullable=True),
|
sa.Column("data_id", sa.UUID(), nullable=True),
|
||||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
sa.Column("file_path", sa.String(), nullable=True),
|
||||||
sa.Column('content_hash', sa.String(), nullable=True),
|
sa.Column("file_size", sa.Integer(), nullable=True),
|
||||||
sa.Column('total_blocks', sa.Integer(), nullable=True),
|
sa.Column("content_hash", sa.String(), nullable=True),
|
||||||
sa.Column('block_size', sa.Integer(), nullable=True),
|
sa.Column("total_blocks", sa.Integer(), nullable=True),
|
||||||
sa.Column('strong_len', sa.Integer(), nullable=True),
|
sa.Column("block_size", sa.Integer(), nullable=True),
|
||||||
sa.Column('signature_data', sa.LargeBinary(), nullable=True),
|
sa.Column("strong_len", sa.Integer(), nullable=True),
|
||||||
sa.Column('blocks_info', sa.JSON(), nullable=True),
|
sa.Column("signature_data", sa.LargeBinary(), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
sa.Column("blocks_info", sa.JSON(), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_file_signatures_data_id"), "file_signatures", ["data_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_file_signatures_content_hash"), "file_signatures", ["content_hash"], unique=False
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_file_signatures_data_id'), 'file_signatures', ['data_id'], unique=False)
|
|
||||||
op.create_index(op.f('ix_file_signatures_content_hash'), 'file_signatures', ['content_hash'], unique=False)
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_index(op.f('ix_file_signatures_content_hash'), table_name='file_signatures')
|
op.drop_index(op.f("ix_file_signatures_content_hash"), table_name="file_signatures")
|
||||||
op.drop_index(op.f('ix_file_signatures_data_id'), table_name='file_signatures')
|
op.drop_index(op.f("ix_file_signatures_data_id"), table_name="file_signatures")
|
||||||
op.drop_table('file_signatures')
|
op.drop_table("file_signatures")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
@ -18,6 +18,7 @@ litellm.set_verbose = False
|
||||||
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
|
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
|
||||||
logging.getLogger("litellm").setLevel(logging.CRITICAL)
|
logging.getLogger("litellm").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
Adapter for Generic API LLM provider API.
|
Adapter for Generic API LLM provider API.
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,9 @@ class FileSignature(Base):
|
||||||
signature_data = Column(LargeBinary)
|
signature_data = Column(LargeBinary)
|
||||||
|
|
||||||
# Block details (JSON array of block info)
|
# Block details (JSON array of block info)
|
||||||
blocks_info = Column(JSON) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
|
blocks_info = Column(
|
||||||
|
JSON
|
||||||
|
) # Array of {block_index, weak_checksum, strong_hash, block_size, file_offset}
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import tempfile
|
||||||
@dataclass
|
@dataclass
|
||||||
class BlockInfo:
|
class BlockInfo:
|
||||||
"""Information about a file block"""
|
"""Information about a file block"""
|
||||||
|
|
||||||
block_index: int
|
block_index: int
|
||||||
weak_checksum: int
|
weak_checksum: int
|
||||||
strong_hash: str
|
strong_hash: str
|
||||||
|
|
@ -28,6 +29,7 @@ class BlockInfo:
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileSignature:
|
class FileSignature:
|
||||||
"""File signature containing block information"""
|
"""File signature containing block information"""
|
||||||
|
|
||||||
file_path: str
|
file_path: str
|
||||||
file_size: int
|
file_size: int
|
||||||
total_blocks: int
|
total_blocks: int
|
||||||
|
|
@ -40,6 +42,7 @@ class FileSignature:
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileDelta:
|
class FileDelta:
|
||||||
"""Delta information for changed blocks"""
|
"""Delta information for changed blocks"""
|
||||||
|
|
||||||
changed_blocks: List[int] # Block indices that changed
|
changed_blocks: List[int] # Block indices that changed
|
||||||
delta_data: bytes
|
delta_data: bytes
|
||||||
old_signature: FileSignature
|
old_signature: FileSignature
|
||||||
|
|
@ -80,9 +83,7 @@ class BlockHashService:
|
||||||
|
|
||||||
# Calculate optimal signature parameters
|
# Calculate optimal signature parameters
|
||||||
magic, block_len, strong_len = get_signature_args(
|
magic, block_len, strong_len = get_signature_args(
|
||||||
file_size,
|
file_size, block_len=self.block_size, strong_len=self.strong_len
|
||||||
block_len=self.block_size,
|
|
||||||
strong_len=self.strong_len
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate signature using librsync
|
# Generate signature using librsync
|
||||||
|
|
@ -102,10 +103,12 @@ class BlockHashService:
|
||||||
block_size=block_len,
|
block_size=block_len,
|
||||||
strong_len=strong_len,
|
strong_len=strong_len,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
signature_data=signature_data
|
signature_data=signature_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_signature(self, signature_data: bytes, file_data: bytes, block_size: int) -> List[BlockInfo]:
|
def _parse_signature(
|
||||||
|
self, signature_data: bytes, file_data: bytes, block_size: int
|
||||||
|
) -> List[BlockInfo]:
|
||||||
"""
|
"""
|
||||||
Parse signature data to extract block information
|
Parse signature data to extract block information
|
||||||
|
|
||||||
|
|
@ -131,13 +134,15 @@ class BlockHashService:
|
||||||
# Calculate strong hash (MD5)
|
# Calculate strong hash (MD5)
|
||||||
strong_hash = hashlib.md5(block_data).hexdigest()
|
strong_hash = hashlib.md5(block_data).hexdigest()
|
||||||
|
|
||||||
blocks.append(BlockInfo(
|
blocks.append(
|
||||||
block_index=i,
|
BlockInfo(
|
||||||
weak_checksum=weak_checksum,
|
block_index=i,
|
||||||
strong_hash=strong_hash,
|
weak_checksum=weak_checksum,
|
||||||
block_size=len(block_data),
|
strong_hash=strong_hash,
|
||||||
file_offset=start_offset
|
block_size=len(block_data),
|
||||||
))
|
file_offset=start_offset,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
@ -185,15 +190,18 @@ class BlockHashService:
|
||||||
if old_block is None or new_block is None:
|
if old_block is None or new_block is None:
|
||||||
# Block was added or removed
|
# Block was added or removed
|
||||||
changed_blocks.append(block_idx)
|
changed_blocks.append(block_idx)
|
||||||
elif (old_block.weak_checksum != new_block.weak_checksum or
|
elif (
|
||||||
old_block.strong_hash != new_block.strong_hash):
|
old_block.weak_checksum != new_block.weak_checksum
|
||||||
|
or old_block.strong_hash != new_block.strong_hash
|
||||||
|
):
|
||||||
# Block content changed
|
# Block content changed
|
||||||
changed_blocks.append(block_idx)
|
changed_blocks.append(block_idx)
|
||||||
|
|
||||||
return sorted(changed_blocks)
|
return sorted(changed_blocks)
|
||||||
|
|
||||||
def generate_delta(self, old_file: BinaryIO, new_file: BinaryIO,
|
def generate_delta(
|
||||||
old_signature: FileSignature = None) -> FileDelta:
|
self, old_file: BinaryIO, new_file: BinaryIO, old_signature: FileSignature = None
|
||||||
|
) -> FileDelta:
|
||||||
"""
|
"""
|
||||||
Generate a delta between two file versions
|
Generate a delta between two file versions
|
||||||
|
|
||||||
|
|
@ -226,7 +234,7 @@ class BlockHashService:
|
||||||
changed_blocks=changed_blocks,
|
changed_blocks=changed_blocks,
|
||||||
delta_data=delta_data,
|
delta_data=delta_data,
|
||||||
old_signature=old_signature,
|
old_signature=old_signature,
|
||||||
new_signature=new_signature
|
new_signature=new_signature,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
|
def apply_delta(self, old_file: BinaryIO, delta_obj: FileDelta) -> BytesIO:
|
||||||
|
|
@ -249,7 +257,9 @@ class BlockHashService:
|
||||||
|
|
||||||
return result_io
|
return result_io
|
||||||
|
|
||||||
def calculate_block_changes(self, old_sig: FileSignature, new_sig: FileSignature) -> Dict[str, Any]:
|
def calculate_block_changes(
|
||||||
|
self, old_sig: FileSignature, new_sig: FileSignature
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Calculate detailed statistics about block changes
|
Calculate detailed statistics about block changes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,9 @@ class IncrementalLoader:
|
||||||
"""
|
"""
|
||||||
self.block_service = BlockHashService(block_size, strong_len)
|
self.block_service = BlockHashService(block_size, strong_len)
|
||||||
|
|
||||||
async def should_process_file(self, file_obj: BinaryIO, data_id: str) -> Tuple[bool, Optional[Dict]]:
|
async def should_process_file(
|
||||||
|
self, file_obj: BinaryIO, data_id: str
|
||||||
|
) -> Tuple[bool, Optional[Dict]]:
|
||||||
"""
|
"""
|
||||||
Determine if a file should be processed based on incremental changes
|
Determine if a file should be processed based on incremental changes
|
||||||
|
|
||||||
|
|
@ -74,18 +76,23 @@ class IncrementalLoader:
|
||||||
service_old_sig = self._db_signature_to_service(existing_signature)
|
service_old_sig = self._db_signature_to_service(existing_signature)
|
||||||
|
|
||||||
# Compare signatures to find changed blocks
|
# Compare signatures to find changed blocks
|
||||||
changed_blocks = self.block_service.compare_signatures(service_old_sig, current_signature)
|
changed_blocks = self.block_service.compare_signatures(
|
||||||
|
service_old_sig, current_signature
|
||||||
|
)
|
||||||
|
|
||||||
if not changed_blocks:
|
if not changed_blocks:
|
||||||
# Signatures match, no processing needed
|
# Signatures match, no processing needed
|
||||||
return False, {"type": "no_changes", "full_processing": False}
|
return False, {"type": "no_changes", "full_processing": False}
|
||||||
|
|
||||||
# Calculate change statistics
|
# Calculate change statistics
|
||||||
change_stats = self.block_service.calculate_block_changes(service_old_sig, current_signature)
|
change_stats = self.block_service.calculate_block_changes(
|
||||||
|
service_old_sig, current_signature
|
||||||
|
)
|
||||||
|
|
||||||
change_info = {
|
change_info = {
|
||||||
"type": "incremental_changes",
|
"type": "incremental_changes",
|
||||||
"full_processing": len(changed_blocks) > (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
|
"full_processing": len(changed_blocks)
|
||||||
|
> (len(service_old_sig.blocks) * 0.7), # >70% changed = full reprocess
|
||||||
"changed_blocks": changed_blocks,
|
"changed_blocks": changed_blocks,
|
||||||
"stats": change_stats,
|
"stats": change_stats,
|
||||||
"new_signature": current_signature,
|
"new_signature": current_signature,
|
||||||
|
|
@ -94,8 +101,9 @@ class IncrementalLoader:
|
||||||
|
|
||||||
return True, change_info
|
return True, change_info
|
||||||
|
|
||||||
async def process_incremental_changes(self, file_obj: BinaryIO, data_id: str,
|
async def process_incremental_changes(
|
||||||
change_info: Dict) -> List[Dict]:
|
self, file_obj: BinaryIO, data_id: str, change_info: Dict
|
||||||
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Process only the changed blocks of a file
|
Process only the changed blocks of a file
|
||||||
|
|
||||||
|
|
@ -135,13 +143,15 @@ class IncrementalLoader:
|
||||||
end_offset = start_offset + block_info.block_size
|
end_offset = start_offset + block_info.block_size
|
||||||
block_data = file_data[start_offset:end_offset]
|
block_data = file_data[start_offset:end_offset]
|
||||||
|
|
||||||
changed_block_data.append({
|
changed_block_data.append(
|
||||||
"block_index": block_idx,
|
{
|
||||||
"block_data": block_data,
|
"block_index": block_idx,
|
||||||
"block_info": block_info,
|
"block_data": block_data,
|
||||||
"file_offset": start_offset,
|
"block_info": block_info,
|
||||||
"block_size": len(block_data),
|
"file_offset": start_offset,
|
||||||
})
|
"block_size": len(block_data),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return changed_block_data
|
return changed_block_data
|
||||||
|
|
||||||
|
|
@ -210,7 +220,9 @@ class IncrementalLoader:
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _get_existing_signature(self, session: AsyncSession, data_id: str) -> Optional[FileSignature]:
|
async def _get_existing_signature(
|
||||||
|
self, session: AsyncSession, data_id: str
|
||||||
|
) -> Optional[FileSignature]:
|
||||||
"""
|
"""
|
||||||
Get existing file signature from database
|
Get existing file signature from database
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,9 @@ async def ingest_data(
|
||||||
incremental_loader = IncrementalLoader()
|
incremental_loader = IncrementalLoader()
|
||||||
|
|
||||||
# Check if file needs incremental processing
|
# Check if file needs incremental processing
|
||||||
should_process, change_info = await incremental_loader.should_process_file(file, data_id)
|
should_process, change_info = await incremental_loader.should_process_file(
|
||||||
|
file, data_id
|
||||||
|
)
|
||||||
|
|
||||||
# Save updated file signature regardless of whether processing is needed
|
# Save updated file signature regardless of whether processing is needed
|
||||||
await incremental_loader.save_file_signature(file, data_id)
|
await incremental_loader.save_file_signature(file, data_id)
|
||||||
|
|
@ -155,7 +157,7 @@ async def ingest_data(
|
||||||
ext_metadata["incremental_processing"] = {
|
ext_metadata["incremental_processing"] = {
|
||||||
"should_process": should_process,
|
"should_process": should_process,
|
||||||
"change_info": change_info,
|
"change_info": change_info,
|
||||||
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat()))
|
"processing_timestamp": json.loads(json.dumps(datetime.now().isoformat())),
|
||||||
}
|
}
|
||||||
|
|
||||||
if data_point is not None:
|
if data_point is not None:
|
||||||
|
|
|
||||||
|
|
@ -51,12 +51,12 @@ def test_AudioDocument(mock_engine):
|
||||||
GROUND_TRUTH,
|
GROUND_TRUTH,
|
||||||
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -34,12 +34,12 @@ def test_ImageDocument(mock_engine):
|
||||||
GROUND_TRUTH,
|
GROUND_TRUTH,
|
||||||
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -36,12 +36,12 @@ def test_PdfDocument(mock_engine):
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
|
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -49,12 +49,12 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
|
||||||
GROUND_TRUTH[input_file],
|
GROUND_TRUTH[input_file],
|
||||||
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
|
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
ground_truth["word_count"] == paragraph_data.chunk_size
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -79,32 +79,32 @@ def test_UnstructuredDocument(mock_engine):
|
||||||
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
|
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
|
||||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test DOCX
|
# Test DOCX
|
||||||
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
|
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
|
||||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_end" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_end != {paragraph_data.cut_type = }"
|
"sentence_end" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_end != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# TEST CSV
|
# TEST CSV
|
||||||
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
|
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
|
||||||
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
assert (
|
||||||
f"Read text doesn't match expected text: {paragraph_data.text}"
|
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
|
||||||
)
|
), f"Read text doesn't match expected text: {paragraph_data.text}"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test XLSX
|
# Test XLSX
|
||||||
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||||
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
|
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
|
||||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ async def check_graph_metrics_consistency_across_adapters(include_optional=False
|
||||||
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
||||||
|
|
||||||
for key, neo4j_value in neo4j_metrics.items():
|
for key, neo4j_value in neo4j_metrics.items():
|
||||||
assert networkx_metrics[key] == neo4j_value, (
|
assert (
|
||||||
f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
|
networkx_metrics[key] == neo4j_value
|
||||||
)
|
), f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,6 @@ async def assert_metrics(provider, include_optional=True):
|
||||||
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")
|
||||||
|
|
||||||
for key, ground_truth_value in ground_truth_metrics.items():
|
for key, ground_truth_value in ground_truth_metrics.items():
|
||||||
assert metrics[key] == ground_truth_value, (
|
assert (
|
||||||
f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
|
metrics[key] == ground_truth_value
|
||||||
)
|
), f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
|
||||||
|
|
|
||||||
|
|
@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
|
||||||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||||
# Get data entry from database based on hash contents
|
# Get data entry from database based on hash contents
|
||||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test deletion of data along with local files created by cognee
|
# Test deletion of data along with local files created by cognee
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert not os.path.exists(data.raw_data_location), (
|
assert not os.path.exists(
|
||||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location still exists after deletion: {data.raw_data_location}"
|
||||||
|
|
||||||
async with engine.get_async_session() as session:
|
async with engine.get_async_session() as session:
|
||||||
# Get data entry from database based on file path
|
# Get data entry from database based on file path
|
||||||
data = (
|
data = (
|
||||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||||
).one()
|
).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test local files not created by cognee won't get deleted
|
# Test local files not created by cognee won't get deleted
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert os.path.exists(data.raw_data_location), (
|
assert os.path.exists(
|
||||||
f"Data location doesn't exists: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exists: {data.raw_data_location}"
|
||||||
|
|
||||||
|
|
||||||
async def test_getting_of_documents(dataset_name_1):
|
async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||||
assert len(document_ids) == 1, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
len(document_ids) == 1
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||||
|
|
||||||
# Test getting of documents for search when no dataset is provided
|
# Test getting of documents for search when no dataset is provided
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id)
|
document_ids = await get_document_ids_for_user(user.id)
|
||||||
assert len(document_ids) == 2, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
len(document_ids) == 2
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,9 @@ async def test_deduplication():
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("data")
|
result = await relational_engine.get_all_data_from_table("data")
|
||||||
assert len(result) == 1, "More than one data entity was found."
|
assert len(result) == 1, "More than one data entity was found."
|
||||||
assert result[0]["name"] == "Natural_language_processing_copy", (
|
assert (
|
||||||
"Result name does not match expected value."
|
result[0]["name"] == "Natural_language_processing_copy"
|
||||||
)
|
), "Result name does not match expected value."
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("datasets")
|
result = await relational_engine.get_all_data_from_table("datasets")
|
||||||
assert len(result) == 2, "Unexpected number of datasets found."
|
assert len(result) == 2, "Unexpected number of datasets found."
|
||||||
|
|
@ -61,9 +61,9 @@ async def test_deduplication():
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("data")
|
result = await relational_engine.get_all_data_from_table("data")
|
||||||
assert len(result) == 1, "More than one data entity was found."
|
assert len(result) == 1, "More than one data entity was found."
|
||||||
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
|
assert (
|
||||||
"Content hash is not a part of file name."
|
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
|
||||||
)
|
), "Content hash is not a part of file name."
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
|
||||||
|
|
@ -92,9 +92,9 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
assert not os.path.exists(get_relational_engine().db_path), (
|
assert not os.path.exists(
|
||||||
"SQLite relational database is not empty"
|
get_relational_engine().db_path
|
||||||
)
|
), "SQLite relational database is not empty"
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,13 +103,13 @@ async def main():
|
||||||
node_name=["nonexistent"],
|
node_name=["nonexistent"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
isinstance(context_nonempty, str) and context_nonempty != ""
|
||||||
)
|
), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||||
|
|
||||||
assert context_empty == "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
context_empty == ""
|
||||||
)
|
), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
|
||||||
|
|
@ -107,13 +107,13 @@ async def main():
|
||||||
node_name=["nonexistent"],
|
node_name=["nonexistent"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
isinstance(context_nonempty, str) and context_nonempty != ""
|
||||||
)
|
), f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||||
|
|
||||||
assert context_empty == "", (
|
assert (
|
||||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
context_empty == ""
|
||||||
)
|
), f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
|
||||||
|
|
@ -23,28 +23,28 @@ async def test_local_file_deletion(data_text, file_location):
|
||||||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||||
# Get data entry from database based on hash contents
|
# Get data entry from database based on hash contents
|
||||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test deletion of data along with local files created by cognee
|
# Test deletion of data along with local files created by cognee
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert not os.path.exists(data.raw_data_location), (
|
assert not os.path.exists(
|
||||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location still exists after deletion: {data.raw_data_location}"
|
||||||
|
|
||||||
async with engine.get_async_session() as session:
|
async with engine.get_async_session() as session:
|
||||||
# Get data entry from database based on file path
|
# Get data entry from database based on file path
|
||||||
data = (
|
data = (
|
||||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||||
).one()
|
).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test local files not created by cognee won't get deleted
|
# Test local files not created by cognee won't get deleted
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert os.path.exists(data.raw_data_location), (
|
assert os.path.exists(
|
||||||
f"Data location doesn't exists: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exists: {data.raw_data_location}"
|
||||||
|
|
||||||
|
|
||||||
async def test_getting_of_documents(dataset_name_1):
|
async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
@ -53,16 +53,16 @@ async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||||
assert len(document_ids) == 1, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
len(document_ids) == 1
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||||
|
|
||||||
# Test getting of documents for search when no dataset is provided
|
# Test getting of documents for search when no dataset is provided
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id)
|
document_ids = await get_document_ids_for_user(user.id)
|
||||||
assert len(document_ids) == 2, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
len(document_ids) == 2
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
||||||
|
|
@ -112,9 +112,9 @@ async def relational_db_migration():
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||||
|
|
||||||
assert len(distinct_node_names) == 12, (
|
assert (
|
||||||
f"Expected 12 distinct node references, found {len(distinct_node_names)}"
|
len(distinct_node_names) == 12
|
||||||
)
|
), f"Expected 12 distinct node references, found {len(distinct_node_names)}"
|
||||||
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
|
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
|
||||||
|
|
||||||
expected_edges = {
|
expected_edges = {
|
||||||
|
|
|
||||||
|
|
@ -29,54 +29,54 @@ async def main():
|
||||||
logging.info(edge_type_counts)
|
logging.info(edge_type_counts)
|
||||||
|
|
||||||
# Assert there is exactly one PdfDocument.
|
# Assert there is exactly one PdfDocument.
|
||||||
assert type_counts.get("PdfDocument", 0) == 1, (
|
assert (
|
||||||
f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
|
type_counts.get("PdfDocument", 0) == 1
|
||||||
)
|
), f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
|
||||||
|
|
||||||
# Assert there is exactly one TextDocument.
|
# Assert there is exactly one TextDocument.
|
||||||
assert type_counts.get("TextDocument", 0) == 1, (
|
assert (
|
||||||
f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
|
type_counts.get("TextDocument", 0) == 1
|
||||||
)
|
), f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
|
||||||
|
|
||||||
# Assert there are at least two DocumentChunk nodes.
|
# Assert there are at least two DocumentChunk nodes.
|
||||||
assert type_counts.get("DocumentChunk", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
|
type_counts.get("DocumentChunk", 0) >= 2
|
||||||
)
|
), f"Expected at least two DocumentChunk nodes, but found {type_counts.get('DocumentChunk', 0)}"
|
||||||
|
|
||||||
# Assert there is at least two TextSummary.
|
# Assert there is at least two TextSummary.
|
||||||
assert type_counts.get("TextSummary", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
|
type_counts.get("TextSummary", 0) >= 2
|
||||||
)
|
), f"Expected at least two TextSummary, but found {type_counts.get('TextSummary', 0)}"
|
||||||
|
|
||||||
# Assert there is at least one Entity.
|
# Assert there is at least one Entity.
|
||||||
assert type_counts.get("Entity", 0) > 0, (
|
assert (
|
||||||
f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
|
type_counts.get("Entity", 0) > 0
|
||||||
)
|
), f"Expected more than zero Entity nodes, but found {type_counts.get('Entity', 0)}"
|
||||||
|
|
||||||
# Assert there is at least one EntityType.
|
# Assert there is at least one EntityType.
|
||||||
assert type_counts.get("EntityType", 0) > 0, (
|
assert (
|
||||||
f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
|
type_counts.get("EntityType", 0) > 0
|
||||||
)
|
), f"Expected more than zero EntityType nodes, but found {type_counts.get('EntityType', 0)}"
|
||||||
|
|
||||||
# Assert that there are at least two 'is_part_of' edges.
|
# Assert that there are at least two 'is_part_of' edges.
|
||||||
assert edge_type_counts.get("is_part_of", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
|
edge_type_counts.get("is_part_of", 0) >= 2
|
||||||
)
|
), f"Expected at least two 'is_part_of' edges, but found {edge_type_counts.get('is_part_of', 0)}"
|
||||||
|
|
||||||
# Assert that there are at least two 'made_from' edges.
|
# Assert that there are at least two 'made_from' edges.
|
||||||
assert edge_type_counts.get("made_from", 0) >= 2, (
|
assert (
|
||||||
f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
|
edge_type_counts.get("made_from", 0) >= 2
|
||||||
)
|
), f"Expected at least two 'made_from' edges, but found {edge_type_counts.get('made_from', 0)}"
|
||||||
|
|
||||||
# Assert that there is at least one 'is_a' edge.
|
# Assert that there is at least one 'is_a' edge.
|
||||||
assert edge_type_counts.get("is_a", 0) >= 1, (
|
assert (
|
||||||
f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
|
edge_type_counts.get("is_a", 0) >= 1
|
||||||
)
|
), f"Expected at least one 'is_a' edge, but found {edge_type_counts.get('is_a', 0)}"
|
||||||
|
|
||||||
# Assert that there is at least one 'contains' edge.
|
# Assert that there is at least one 'contains' edge.
|
||||||
assert edge_type_counts.get("contains", 0) >= 1, (
|
assert (
|
||||||
f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
|
edge_type_counts.get("contains", 0) >= 1
|
||||||
)
|
), f"Expected at least one 'contains' edge, but found {edge_type_counts.get('contains', 0)}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -66,9 +66,9 @@ async def main():
|
||||||
assert isinstance(context, str), f"{name}: Context should be a string"
|
assert isinstance(context, str), f"{name}: Context should be a string"
|
||||||
assert context.strip(), f"{name}: Context should not be empty"
|
assert context.strip(), f"{name}: Context should not be empty"
|
||||||
lower = context.lower()
|
lower = context.lower()
|
||||||
assert "germany" in lower or "netherlands" in lower, (
|
assert (
|
||||||
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
"germany" in lower or "netherlands" in lower
|
||||||
)
|
), f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
||||||
|
|
||||||
triplets_gk = await GraphCompletionRetriever().get_triplets(
|
triplets_gk = await GraphCompletionRetriever().get_triplets(
|
||||||
query="Next to which country is Germany located?"
|
query="Next to which country is Germany located?"
|
||||||
|
|
@ -96,18 +96,18 @@ async def main():
|
||||||
distance = edge.attributes.get("vector_distance")
|
distance = edge.attributes.get("vector_distance")
|
||||||
node1_distance = edge.node1.attributes.get("vector_distance")
|
node1_distance = edge.node1.attributes.get("vector_distance")
|
||||||
node2_distance = edge.node2.attributes.get("vector_distance")
|
node2_distance = edge.node2.attributes.get("vector_distance")
|
||||||
assert isinstance(distance, float), (
|
assert isinstance(
|
||||||
f"{name}: vector_distance should be float, got {type(distance)}"
|
distance, float
|
||||||
)
|
), f"{name}: vector_distance should be float, got {type(distance)}"
|
||||||
assert 0 <= distance <= 1, (
|
assert (
|
||||||
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
0 <= distance <= 1
|
||||||
)
|
), f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
assert 0 <= node1_distance <= 1, (
|
assert (
|
||||||
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
0 <= node1_distance <= 1
|
||||||
)
|
), f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
assert 0 <= node2_distance <= 1, (
|
assert (
|
||||||
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
0 <= node2_distance <= 1
|
||||||
)
|
), f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
||||||
|
|
||||||
completion_gk = await cognee.search(
|
completion_gk = await cognee.search(
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
|
@ -137,9 +137,9 @@ async def main():
|
||||||
text = completion[0]
|
text = completion[0]
|
||||||
assert isinstance(text, str), f"{name}: element should be a string"
|
assert isinstance(text, str), f"{name}: element should be a string"
|
||||||
assert text.strip(), f"{name}: string should not be empty"
|
assert text.strip(), f"{name}: string should not be empty"
|
||||||
assert "netherlands" in text.lower(), (
|
assert (
|
||||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
"netherlands" in text.lower()
|
||||||
)
|
), f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,12 @@ async def test_answer_generation():
|
||||||
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
|
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
|
||||||
|
|
||||||
assert len(answers) == len(qa_pairs)
|
assert len(answers) == len(qa_pairs)
|
||||||
assert answers[0]["question"] == qa_pairs[0]["question"], (
|
assert (
|
||||||
"AnswerGeneratorExecutor is passing the question incorrectly"
|
answers[0]["question"] == qa_pairs[0]["question"]
|
||||||
)
|
), "AnswerGeneratorExecutor is passing the question incorrectly"
|
||||||
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
|
assert (
|
||||||
"AnswerGeneratorExecutor is passing the golden answer incorrectly"
|
answers[0]["golden_answer"] == qa_pairs[0]["answer"]
|
||||||
)
|
), "AnswerGeneratorExecutor is passing the golden answer incorrectly"
|
||||||
assert answers[0]["answer"] == "Mocked answer", (
|
assert (
|
||||||
"AnswerGeneratorExecutor is passing the generated answer incorrectly"
|
answers[0]["answer"] == "Mocked answer"
|
||||||
)
|
), "AnswerGeneratorExecutor is passing the generated answer incorrectly"
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,9 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
||||||
|
|
||||||
corpus_list, qa_pairs = result
|
corpus_list, qa_pairs = result
|
||||||
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
|
assert isinstance(corpus_list, list), f"{AdapterClass.__name__} corpus_list is not a list."
|
||||||
assert isinstance(qa_pairs, list), (
|
assert isinstance(
|
||||||
f"{AdapterClass.__name__} question_answer_pairs is not a list."
|
qa_pairs, list
|
||||||
)
|
), f"{AdapterClass.__name__} question_answer_pairs is not a list."
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
|
@pytest.mark.parametrize("AdapterClass", ADAPTER_CLASSES)
|
||||||
|
|
@ -71,9 +71,9 @@ def test_adapter_returns_some_content(AdapterClass):
|
||||||
# We don't know how large the dataset is, but we expect at least 1 item
|
# We don't know how large the dataset is, but we expect at least 1 item
|
||||||
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
|
assert len(corpus_list) > 0, f"{AdapterClass.__name__} returned an empty corpus_list."
|
||||||
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
|
assert len(qa_pairs) > 0, f"{AdapterClass.__name__} returned an empty question_answer_pairs."
|
||||||
assert len(qa_pairs) <= limit, (
|
assert (
|
||||||
f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
|
len(qa_pairs) <= limit
|
||||||
)
|
), f"{AdapterClass.__name__} returned more QA items than requested limit={limit}."
|
||||||
|
|
||||||
for item in qa_pairs:
|
for item in qa_pairs:
|
||||||
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."
|
assert "question" in item, f"{AdapterClass.__name__} missing 'question' key in QA pair."
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ def test_corpus_builder_load_corpus(benchmark):
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||||
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
||||||
assert len(questions) <= 2, (
|
assert (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
len(questions) <= 2
|
||||||
)
|
), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -24,6 +24,6 @@ async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
||||||
limit = 2
|
limit = 2
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
questions = await corpus_builder.build_corpus(limit=limit)
|
questions = await corpus_builder.build_corpus(limit=limit)
|
||||||
assert len(questions) <= 2, (
|
assert (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
len(questions) <= 2
|
||||||
)
|
), f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
|
|
|
||||||
|
|
@ -52,14 +52,14 @@ def test_metrics(metrics, actual, expected, expected_exact_score, expected_f1_ra
|
||||||
test_case = MockTestCase(actual, expected)
|
test_case = MockTestCase(actual, expected)
|
||||||
|
|
||||||
exact_match_score = metrics["exact_match"].measure(test_case)
|
exact_match_score = metrics["exact_match"].measure(test_case)
|
||||||
assert exact_match_score == expected_exact_score, (
|
assert (
|
||||||
f"Exact match failed for '{actual}' vs '{expected}'"
|
exact_match_score == expected_exact_score
|
||||||
)
|
), f"Exact match failed for '{actual}' vs '{expected}'"
|
||||||
|
|
||||||
f1_score = metrics["f1"].measure(test_case)
|
f1_score = metrics["f1"].measure(test_case)
|
||||||
assert expected_f1_range[0] <= f1_score <= expected_f1_range[1], (
|
assert (
|
||||||
f"F1 score failed for '{actual}' vs '{expected}'"
|
expected_f1_range[0] <= f1_score <= expected_f1_range[1]
|
||||||
)
|
), f"F1 score failed for '{actual}' vs '{expected}'"
|
||||||
|
|
||||||
|
|
||||||
class TestBootstrapCI(unittest.TestCase):
|
class TestBootstrapCI(unittest.TestCase):
|
||||||
|
|
|
||||||
|
|
@ -157,15 +157,15 @@ def test_rate_limit_60_per_minute():
|
||||||
if len(failures) > 0:
|
if len(failures) > 0:
|
||||||
first_failure_idx = int(failures[0].split()[1])
|
first_failure_idx = int(failures[0].split()[1])
|
||||||
print(f"First failure occurred at request index: {first_failure_idx}")
|
print(f"First failure occurred at request index: {first_failure_idx}")
|
||||||
assert 58 <= first_failure_idx <= 62, (
|
assert (
|
||||||
f"Expected first failure around request #60, got #{first_failure_idx}"
|
58 <= first_failure_idx <= 62
|
||||||
)
|
), f"Expected first failure around request #60, got #{first_failure_idx}"
|
||||||
|
|
||||||
# Calculate requests per minute
|
# Calculate requests per minute
|
||||||
rate_per_minute = len(successes)
|
rate_per_minute = len(successes)
|
||||||
print(f"Rate: {rate_per_minute} requests per minute")
|
print(f"Rate: {rate_per_minute} requests per minute")
|
||||||
|
|
||||||
# Verify the rate is close to our target of 60 requests per minute
|
# Verify the rate is close to our target of 60 requests per minute
|
||||||
assert 58 <= rate_per_minute <= 62, (
|
assert (
|
||||||
f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
|
58 <= rate_per_minute <= 62
|
||||||
)
|
), f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
|
||||||
|
|
|
||||||
|
|
@ -110,9 +110,9 @@ def test_sync_retry():
|
||||||
print(f"Number of attempts: {test_function_sync.counter}")
|
print(f"Number of attempts: {test_function_sync.counter}")
|
||||||
|
|
||||||
# The function should succeed on the 3rd attempt (after 2 failures)
|
# The function should succeed on the 3rd attempt (after 2 failures)
|
||||||
assert test_function_sync.counter == 3, (
|
assert (
|
||||||
f"Expected 3 attempts, got {test_function_sync.counter}"
|
test_function_sync.counter == 3
|
||||||
)
|
), f"Expected 3 attempts, got {test_function_sync.counter}"
|
||||||
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
||||||
|
|
||||||
print("✅ PASS: Synchronous retry mechanism is working correctly")
|
print("✅ PASS: Synchronous retry mechanism is working correctly")
|
||||||
|
|
@ -143,9 +143,9 @@ async def test_async_retry():
|
||||||
print(f"Number of attempts: {test_function_async.counter}")
|
print(f"Number of attempts: {test_function_async.counter}")
|
||||||
|
|
||||||
# The function should succeed on the 3rd attempt (after 2 failures)
|
# The function should succeed on the 3rd attempt (after 2 failures)
|
||||||
assert test_function_async.counter == 3, (
|
assert (
|
||||||
f"Expected 3 attempts, got {test_function_async.counter}"
|
test_function_async.counter == 3
|
||||||
)
|
), f"Expected 3 attempts, got {test_function_async.counter}"
|
||||||
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
|
||||||
|
|
||||||
print("✅ PASS: Asynchronous retry mechanism is working correctly")
|
print("✅ PASS: Asynchronous retry mechanism is working correctly")
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Canva?")
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_extension_context_complex(self):
|
async def test_graph_completion_extension_context_complex(self):
|
||||||
|
|
@ -136,9 +136,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||||
|
|
@ -167,9 +167,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -55,9 +55,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Canva?")
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_cot_context_complex(self):
|
async def test_graph_completion_cot_context_complex(self):
|
||||||
|
|
@ -134,9 +134,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||||
|
|
@ -165,9 +165,9 @@ class TestGraphCompletionRetriever:
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(
|
||||||
"Answer must contain only non-empty strings"
|
isinstance(item, str) and item.strip() for item in answer
|
||||||
)
|
), "Answer must contain only non-empty strings"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,9 @@ max_chunk_size_vals = [512, 1024, 4096]
|
||||||
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
|
def test_chunk_by_paragraph_isomorphism(input_text, max_chunk_size, batch_paragraphs):
|
||||||
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
|
chunks = chunk_by_paragraph(input_text, max_chunk_size, batch_paragraphs)
|
||||||
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
|
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -54,9 +54,9 @@ def test_paragraph_chunk_length(input_text, max_chunk_size, batch_paragraphs):
|
||||||
)
|
)
|
||||||
|
|
||||||
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
|
larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
|
||||||
assert np.all(chunk_lengths <= max_chunk_size), (
|
assert np.all(
|
||||||
f"{max_chunk_size = }: {larger_chunks} are too large"
|
chunk_lengths <= max_chunk_size
|
||||||
)
|
), f"{max_chunk_size = }: {larger_chunks} are too large"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -76,6 +76,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, max_chunk_size, batch_pa
|
||||||
batch_paragraphs=batch_paragraphs,
|
batch_paragraphs=batch_paragraphs,
|
||||||
)
|
)
|
||||||
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
||||||
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
|
assert np.all(
|
||||||
f"{chunk_indices = } are not monotonically increasing"
|
chunk_indices == np.arange(len(chunk_indices))
|
||||||
)
|
), f"{chunk_indices = } are not monotonically increasing"
|
||||||
|
|
|
||||||
|
|
@ -71,9 +71,9 @@ def run_chunking_test(test_text, expected_chunks, mock_engine):
|
||||||
|
|
||||||
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
|
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
|
||||||
for key in ["text", "chunk_size", "cut_type"]:
|
for key in ["text", "chunk_size", "cut_type"]:
|
||||||
assert chunk[key] == expected_chunks_item[key], (
|
assert (
|
||||||
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
|
chunk[key] == expected_chunks_item[key]
|
||||||
)
|
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
|
||||||
|
|
||||||
|
|
||||||
def test_chunking_whole_text():
|
def test_chunking_whole_text():
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ maximum_length_vals = [None, 16, 64]
|
||||||
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
|
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
|
||||||
chunks = chunk_by_sentence(input_text, maximum_length)
|
chunks = chunk_by_sentence(input_text, maximum_length)
|
||||||
reconstructed_text = "".join([chunk[1] for chunk in chunks])
|
reconstructed_text = "".join([chunk[1] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -40,9 +40,9 @@ def test_paragraph_chunk_length(input_text, maximum_length):
|
||||||
)
|
)
|
||||||
|
|
||||||
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
|
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
|
||||||
assert np.all(chunk_lengths <= maximum_length), (
|
assert np.all(
|
||||||
f"{maximum_length = }: {larger_chunks} are too large"
|
chunk_lengths <= maximum_length
|
||||||
)
|
), f"{maximum_length = }: {larger_chunks} are too large"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS, INPUT_TE
|
||||||
def test_chunk_by_word_isomorphism(input_text):
|
def test_chunk_by_word_isomorphism(input_text):
|
||||||
chunks = chunk_by_word(input_text)
|
chunks = chunk_by_word(input_text)
|
||||||
reconstructed_text = "".join([chunk[0] for chunk in chunks])
|
reconstructed_text = "".join([chunk[0] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -66,8 +66,12 @@ This is the end of the modified content.
|
||||||
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
||||||
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
||||||
|
|
||||||
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks")
|
print(
|
||||||
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks")
|
f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
|
||||||
# Compare signatures to find changes
|
# Compare signatures to find changes
|
||||||
print("\n2. Comparing signatures to detect changes...")
|
print("\n2. Comparing signatures to detect changes...")
|
||||||
|
|
@ -77,7 +81,9 @@ This is the end of the modified content.
|
||||||
|
|
||||||
print(f" Changed blocks: {changed_blocks}")
|
print(f" Changed blocks: {changed_blocks}")
|
||||||
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
||||||
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}")
|
print(
|
||||||
|
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Generate delta
|
# Generate delta
|
||||||
print("\n3. Generating delta for changed content...")
|
print("\n3. Generating delta for changed content...")
|
||||||
|
|
@ -145,7 +151,7 @@ async def demonstrate_with_cognee():
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# Create a temporary file
|
# Create a temporary file
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
|
||||||
f.write("Initial content for Cognee processing.")
|
f.write("Initial content for Cognee processing.")
|
||||||
temp_file_path = f.name
|
temp_file_path = f.name
|
||||||
|
|
||||||
|
|
@ -158,7 +164,7 @@ async def demonstrate_with_cognee():
|
||||||
print(" ✅ File added successfully")
|
print(" ✅ File added successfully")
|
||||||
|
|
||||||
# Modify the file
|
# Modify the file
|
||||||
with open(temp_file_path, 'w') as f:
|
with open(temp_file_path, "w") as f:
|
||||||
f.write("Modified content for Cognee processing with additional text.")
|
f.write("Modified content for Cognee processing with additional text.")
|
||||||
|
|
||||||
print("2. Adding modified version of the same file...")
|
print("2. Adding modified version of the same file...")
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ Simple test for incremental loading functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(0, '.')
|
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from cognee.modules.ingestion.incremental import BlockHashService
|
from cognee.modules.ingestion.incremental import BlockHashService
|
||||||
|
|
@ -47,8 +48,12 @@ Line 6: NEW - This is additional content"""
|
||||||
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
initial_signature = block_service.generate_signature(initial_file, "test_file.txt")
|
||||||
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
modified_signature = block_service.generate_signature(modified_file, "test_file.txt")
|
||||||
|
|
||||||
print(f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks")
|
print(
|
||||||
print(f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks")
|
f" Initial file: {initial_signature.file_size} bytes, {initial_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Modified file: {modified_signature.file_size} bytes, {modified_signature.total_blocks} blocks"
|
||||||
|
)
|
||||||
|
|
||||||
# Compare signatures to find changes
|
# Compare signatures to find changes
|
||||||
print("\n2. Comparing signatures to detect changes...")
|
print("\n2. Comparing signatures to detect changes...")
|
||||||
|
|
@ -58,7 +63,9 @@ Line 6: NEW - This is additional content"""
|
||||||
|
|
||||||
print(f" Changed blocks: {changed_blocks}")
|
print(f" Changed blocks: {changed_blocks}")
|
||||||
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
print(f" Compression ratio: {change_stats['compression_ratio']:.2%}")
|
||||||
print(f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}")
|
print(
|
||||||
|
f" Total blocks changed: {change_stats['changed_blocks']} out of {change_stats['total_old_blocks']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Generate delta
|
# Generate delta
|
||||||
print("\n3. Generating delta for changed content...")
|
print("\n3. Generating delta for changed content...")
|
||||||
|
|
|
||||||
|
|
@ -144,13 +144,13 @@ class TestIncrementalLoader:
|
||||||
block_size=10,
|
block_size=10,
|
||||||
strong_len=8,
|
strong_len=8,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
signature_data=b"signature"
|
signature_data=b"signature",
|
||||||
)
|
)
|
||||||
|
|
||||||
change_info = {
|
change_info = {
|
||||||
"type": "incremental_changes",
|
"type": "incremental_changes",
|
||||||
"changed_blocks": [1], # Only middle block changed
|
"changed_blocks": [1], # Only middle block changed
|
||||||
"new_signature": signature
|
"new_signature": signature,
|
||||||
}
|
}
|
||||||
|
|
||||||
# This would normally be called after should_process_file
|
# This would normally be called after should_process_file
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue