feat: csv ingestion & chunking (#1574)
<!-- .github/pull_request_template.md -->

## Description
<!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. -->

Create a dedicated CSV ingestion path with a custom loader and custom chunker that preserves row-column relationships in the produced chunks. #1348

## Type of Change
<!-- Please check the relevant option -->

- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [x] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [x] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [x] My code follows the project's coding standards and style guidelines
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
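For orientation, here is a minimal end-to-end sketch of how the new CSV path could be exercised from the public API. It assumes the top-level `cognee.add` / `cognee.cognify` / `cognee.search` entry points behave as in current releases and that `cognify` accepts a `chunker` argument; those names are not part of this PR and may differ.

```python
import asyncio

import cognee
from cognee.modules.chunking.CsvChunker import CsvChunker  # added in this PR


async def main():
    # Ingest the sample CSV shipped with this PR's tests (assumed API shape).
    await cognee.add("cognee/tests/test_data/example_with_header.csv")

    # Route chunking through the row-aware CsvChunker instead of TextChunker.
    await cognee.cognify(chunker=CsvChunker)

    # Query the resulting knowledge graph.
    results = await cognee.search(query_text="Which city does Eric live in?")
    print(results)


asyncio.run(main())
```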
Commit 2f2a4487f0

18 changed files with 403 additions and 15 deletions
@@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasoning

Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
-2. **Permission Validation**: Ensures user has processing rights
+2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
@@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge

            chunker_class = LangchainChunker
        except ImportError:
            fmt.warning("LangchainChunker not available, using TextChunker")
+    elif args.chunker == "CsvChunker":
+        try:
+            from cognee.modules.chunking.CsvChunker import CsvChunker
+
+            chunker_class = CsvChunker
+        except ImportError:
+            fmt.warning("CsvChunker not available, using TextChunker")

    result = await cognee.cognify(
        datasets=datasets,
@@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]

# Chunker choices
-CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
+CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]

# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
@@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type

        file_type = Type("text/plain", "txt")
        return file_type

+    if ext in [".csv"]:
+        file_type = Type("text/csv", "csv")
+        return file_type
+
    file_type = filetype.guess(file)

    # If file type could not be determined consider it a plain text file as they don't have magic number encoding
@@ -31,6 +31,7 @@ class LoaderEngine:

            "pypdf_loader",
            "image_loader",
            "audio_loader",
+            "csv_loader",
            "unstructured_loader",
            "advanced_pdf_loader",
        ]
@@ -3,5 +3,6 @@

from .text_loader import TextLoader
from .audio_loader import AudioLoader
from .image_loader import ImageLoader
+from .csv_loader import CsvLoader

-__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
+__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
cognee/infrastructure/loaders/core/csv_loader.py (new file, 93 lines)
@@ -0,0 +1,93 @@

import os
from typing import List
import csv
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata


class CsvLoader(LoaderInterface):
    """
    Core CSV file loader that handles basic CSV file formats.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported CSV file extensions."""
        return [
            "csv",
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for CSV content."""
        return [
            "text/csv",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "csv_loader"

    def can_handle(self, extension: str, mime_type: str) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            extension: File extension
            mime_type: Optional MIME type

        Returns:
            True if file can be handled, False otherwise
        """
        if extension in self.supported_extensions and mime_type in self.supported_mime_types:
            return True

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
        """
        Load and process the CSV file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            Full path of the stored text file containing the row-serialized CSV content

        Raises:
            FileNotFoundError: If file doesn't exist
            UnicodeDecodeError: If file cannot be decoded with specified encoding
            OSError: If file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, "rb") as f:
            file_metadata = await get_file_metadata(f)
            # Name ingested file of current loader based on original file content hash
            storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

        row_texts = []
        row_index = 1

        with open(file_path, "r", encoding=encoding, newline="") as file:
            reader = csv.DictReader(file)
            for row in reader:
                pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
                row_text = ", ".join(pairs)
                row_texts.append(f"Row {row_index}:\n{row_text}\n")
                row_index += 1

        content = "\n".join(row_texts)

        storage_config = get_storage_config()
        data_root_directory = storage_config["data_root_directory"]
        storage = get_file_storage(data_root_directory)

        full_file_path = await storage.store(storage_file_name, content)

        return full_file_path
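To make the loader's output format concrete, the following stand-alone snippet reproduces the same row serialization on the sample CSV added in this PR, using only the standard library; the path and printed output are illustrative.

```python
import csv

with open("cognee/tests/test_data/example_with_header.csv", newline="") as f:
    rows = [
        f"Row {index}:\n" + ", ".join(f"{key}: {value}" for key, value in row.items()) + "\n"
        for index, row in enumerate(csv.DictReader(f), start=1)
    ]

# Mirrors the content CsvLoader.load writes to storage:
# Row 1:
# id: 1, name: Eric, age: 30, city: Beijing, country: China
#
# Row 2:
# id: 2, name: Joe, age: 35, city: Berlin, country: Germany
print("\n".join(rows))
```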
@@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):

    @property
    def supported_extensions(self) -> List[str]:
        """Supported text file extensions."""
-        return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
+        return ["txt", "md", "json", "xml", "yaml", "yml", "log"]

    @property
    def supported_mime_types(self) -> List[str]:
@@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):

        return [
            "text/plain",
            "text/markdown",
-            "text/csv",
            "application/json",
            "text/xml",
            "application/xml",
@@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):

        if value is None:
            return ""
        return str(value).replace("\xa0", " ").strip()
-
-
-if __name__ == "__main__":
-    loader = AdvancedPdfLoader()
-    asyncio.run(
-        loader.load(
-            "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
-        )
-    )
@@ -1,5 +1,5 @@

from cognee.infrastructure.loaders.external import PyPdfLoader
-from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
+from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader

# Registry for loader implementations
supported_loaders = {
@@ -7,6 +7,7 @@ supported_loaders = {

    TextLoader.loader_name: TextLoader,
    ImageLoader.loader_name: ImageLoader,
    AudioLoader.loader_name: AudioLoader,
+    CsvLoader.loader_name: CsvLoader,
}

# Try adding optional loaders
cognee/modules/chunking/CsvChunker.py (new file, 35 lines)
@@ -0,0 +1,35 @@

from cognee.shared.logging_utils import get_logger


from cognee.tasks.chunks import chunk_by_row
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk

logger = get_logger()


class CsvChunker(Chunker):
    async def read(self):
        async for content_text in self.get_text():
            if content_text is None:
                continue

            for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
                if chunk_data["chunk_size"] <= self.max_chunk_size:
                    yield DocumentChunk(
                        id=chunk_data["chunk_id"],
                        text=chunk_data["text"],
                        chunk_size=chunk_data["chunk_size"],
                        is_part_of=self.document,
                        chunk_index=self.chunk_index,
                        cut_type=chunk_data["cut_type"],
                        contains=[],
                        metadata={
                            "index_fields": ["text"],
                        },
                    )
                    self.chunk_index += 1
                else:
                    raise ValueError(
                        f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
                    )
cognee/modules/data/processing/document_types/CsvDocument.py (new file, 33 lines)
@@ -0,0 +1,33 @@

import io
import csv
from typing import Type

from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document


class CsvDocument(Document):
    type: str = "csv"
    mime_type: str = "text/csv"

    async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
        async def get_text():
            async with open_data_file(
                self.raw_data_location, mode="r", encoding="utf-8", newline=""
            ) as file:
                content = file.read()
                file_like_obj = io.StringIO(content)
                reader = csv.DictReader(file_like_obj)

                for row in reader:
                    pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
                    row_text = ", ".join(pairs)
                    if not row_text.strip():
                        break
                    yield row_text

        chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)

        async for chunk in chunker.read():
            yield chunk
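A minimal consumption sketch for `CsvDocument`, mirroring the integration test further below. It assumes a configured (or stubbed) embedding engine, since the chunker ultimately sizes each `key: value` pair through `get_embedding_engine()`.

```python
import asyncio
import uuid

from cognee.modules.chunking.CsvChunker import CsvChunker
from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument


async def main():
    document = CsvDocument(
        id=uuid.uuid4(),
        name="example_with_header.csv",
        raw_data_location="cognee/tests/test_data/example_with_header.csv",
        external_metadata="",
        mime_type="text/csv",
    )
    # Each yielded DocumentChunk carries the row-derived text plus its size, index and cut type.
    async for chunk in document.read(chunker_cls=CsvChunker, max_chunk_size=128):
        print(chunk.chunk_index, chunk.cut_type, chunk.text)


asyncio.run(main())
```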
@@ -4,3 +4,4 @@ from .TextDocument import TextDocument

from .ImageDocument import ImageDocument
from .AudioDocument import AudioDocument
from .UnstructuredDocument import UnstructuredDocument
+from .CsvDocument import CsvDocument
@@ -1,4 +1,5 @@

from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph
+from .chunk_by_row import chunk_by_row
from .remove_disconnected_chunks import remove_disconnected_chunks
cognee/tasks/chunks/chunk_by_row.py (new file, 94 lines)
@@ -0,0 +1,94 @@

from typing import Any, Dict, Iterator
from uuid import NAMESPACE_OID, uuid5

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine


def _get_pair_size(pair_text: str) -> int:
    """
    Calculate the size of a given text in terms of tokens.

    If an embedding engine's tokenizer is available, count the tokens for the provided pair.
    If the tokenizer is not available, fall back to a fixed estimate of 3 tokens per pair.

    Parameters:
    -----------

    - pair_text (str): The key:value pair text for which the token size is to be calculated.

    Returns:
    --------

    - int: The number of tokens representing the text, typically an integer, depending
      on the tokenizer's output.
    """
    embedding_engine = get_embedding_engine()
    if embedding_engine.tokenizer:
        return embedding_engine.tokenizer.count_tokens(pair_text)
    else:
        return 3


def chunk_by_row(
    data: str,
    max_chunk_size,
) -> Iterator[Dict[str, Any]]:
    """
    Chunk the input text by row while enabling exact text reconstruction.

    This function divides the given text data into smaller chunks on a line-by-line basis,
    ensuring that the size of each chunk is less than or equal to the specified maximum
    chunk size. It guarantees that when the generated chunks are joined back together with
    ", ", they reproduce the original row text accurately. The tokenization process is
    handled by adapters compatible with the vector engine's embedding model.

    Parameters:
    -----------

    - data (str): The input text to be chunked.
    - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or
      words.
    """
    current_chunk_list = []
    chunk_index = 0
    current_chunk_size = 0

    lines = data.split("\n\n")
    for line in lines:
        pairs_text = line.split(", ")

        for pair_text in pairs_text:
            pair_size = _get_pair_size(pair_text)
            if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
                # Yield current cut chunk
                current_chunk = ", ".join(current_chunk_list)
                chunk_dict = {
                    "text": current_chunk,
                    "chunk_size": current_chunk_size,
                    "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                    "chunk_index": chunk_index,
                    "cut_type": "row_cut",
                }

                yield chunk_dict

                # Start new chunk with current pair text
                current_chunk_list = []
                current_chunk_size = 0
                chunk_index += 1

            current_chunk_list.append(pair_text)
            current_chunk_size += pair_size

    # Yield row chunk
    current_chunk = ", ".join(current_chunk_list)
    if current_chunk:
        chunk_dict = {
            "text": current_chunk,
            "chunk_size": current_chunk_size,
            "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
            "chunk_index": chunk_index,
            "cut_type": "row_end",
        }

        yield chunk_dict
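The sketch below illustrates `chunk_by_row` on a single serialized row, with `get_embedding_engine` patched out the same way the tests do. The whitespace-based token counter is an assumption for illustration only, so the exact split points will differ with a real tokenizer.

```python
from unittest.mock import MagicMock, patch

from cognee.tasks.chunks import chunk_by_row

row_text = "id: 1, name: Eric, age: 30, city: Beijing, country: China"

# Stub engine: count whitespace-separated words as tokens (illustrative assumption).
fake_engine = MagicMock()
fake_engine.tokenizer.count_tokens = lambda text: len(text.split())

with patch("cognee.tasks.chunks.chunk_by_row.get_embedding_engine", return_value=fake_engine):
    for chunk in chunk_by_row(row_text, max_chunk_size=4):
        print(chunk["chunk_index"], chunk["cut_type"], repr(chunk["text"]))
# Expected shape of the output: a few "row_cut" chunks followed by one final "row_end" chunk,
# which join back to the original row text with ", ".
```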
@@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (

    ImageDocument,
    TextDocument,
    UnstructuredDocument,
+    CsvDocument,
)
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id
@@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError

EXTENSION_TO_DOCUMENT_CLASS = {
    "pdf": PdfDocument,  # Text documents
    "txt": TextDocument,
+    "csv": CsvDocument,
    "docx": UnstructuredDocument,
    "doc": UnstructuredDocument,
    "odt": UnstructuredDocument,
cognee/tests/integration/documents/CsvDocument_test.py (new file, 70 lines)
@@ -0,0 +1,70 @@

import os
import sys
import uuid
import pytest
import pathlib
from unittest.mock import patch

from cognee.modules.chunking.CsvChunker import CsvChunker
from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip

chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")


GROUND_TRUTH = {
    "chunk_size_10": [
        {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
        {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
        {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
        {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
    ],
    "chunk_size_128": [
        {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
        {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
    ],
}


@pytest.mark.parametrize(
    "input_file,chunk_size",
    [("example_with_header.csv", 10), ("example_with_header.csv", 128)],
)
@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
@pytest.mark.asyncio
async def test_CsvDocument(mock_engine, input_file, chunk_size):
    # Define file paths of test data
    csv_file_path = os.path.join(
        pathlib.Path(__file__).parent.parent.parent,
        "test_data",
        input_file,
    )

    # Define test documents
    csv_document = CsvDocument(
        id=uuid.uuid4(),
        name="example_with_header.csv",
        raw_data_location=csv_file_path,
        external_metadata="",
        mime_type="text/csv",
    )

    # TEST CSV
    ground_truth_key = f"chunk_size_{chunk_size}"
    async for ground_truth, row_data in async_gen_zip(
        GROUND_TRUTH[ground_truth_key],
        csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
    ):
        assert ground_truth["token_count"] == row_data.chunk_size, (
            f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
        )
        assert ground_truth["len_text"] == len(row_data.text), (
            f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
        )
        assert ground_truth["cut_type"] == row_data.cut_type, (
            f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
        )
        assert ground_truth["chunk_index"] == row_data.chunk_index, (
            f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
        )
cognee/tests/test_data/example_with_header.csv (new file, 3 lines)
@@ -0,0 +1,3 @@

id,name,age,city,country
1,Eric,30,Beijing,China
2,Joe,35,Berlin,Germany
cognee/tests/unit/processing/chunks/chunk_by_row_test.py (new file, 52 lines)
@@ -0,0 +1,52 @@

from itertools import product

import numpy as np
import pytest

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.tasks.chunks import chunk_by_row

INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
max_chunk_size_vals = [8, 32]


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
    chunks = chunk_by_row(input_text, max_chunk_size)
    reconstructed_text = ", ".join([chunk["text"] for chunk in chunks])
    assert reconstructed_text == input_text, (
        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_row_chunk_length(input_text, max_chunk_size):
    chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))
    embedding_engine = get_embedding_engine()

    chunk_lengths = np.array(
        [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
    )

    larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
    assert np.all(chunk_lengths <= max_chunk_size), (
        f"{max_chunk_size = }: {larger_chunks} are too large"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
    chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)
    chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
    assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
        f"{chunk_indices = } are not monotonically increasing"
    )