From 742866b4c9f1d4aa53ab60fb54b79474fbfea0d2 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Wed, 22 Oct 2025 16:56:46 +0800 Subject: [PATCH] feat: csv ingestion loader & chunk Signed-off-by: EricXiao --- cognee/cli/commands/cognify_command.py | 9 +- cognee/cli/config.py | 2 +- .../files/utils/guess_file_type.py | 43 +++++ .../files/utils/is_csv_content.py | 181 ++++++++++++++++++ cognee/infrastructure/loaders/LoaderEngine.py | 1 + .../infrastructure/loaders/core/__init__.py | 3 +- .../infrastructure/loaders/core/csv_loader.py | 93 +++++++++ .../loaders/core/text_loader.py | 3 +- .../loaders/supported_loaders.py | 3 +- cognee/modules/chunking/CsvChunker.py | 35 ++++ .../processing/document_types/CsvDocument.py | 33 ++++ .../processing/document_types/__init__.py | 1 + cognee/tasks/chunks/__init__.py | 1 + cognee/tasks/chunks/chunk_by_row.py | 94 +++++++++ cognee/tasks/documents/classify_documents.py | 2 + .../integration/documents/CsvDocument_test.py | 70 +++++++ .../tests/test_data/example_with_header.csv | 3 + .../processing/chunks/chunk_by_row_test.py | 52 +++++ 18 files changed, 623 insertions(+), 6 deletions(-) create mode 100644 cognee/infrastructure/files/utils/is_csv_content.py create mode 100644 cognee/infrastructure/loaders/core/csv_loader.py create mode 100644 cognee/modules/chunking/CsvChunker.py create mode 100644 cognee/modules/data/processing/document_types/CsvDocument.py create mode 100644 cognee/tasks/chunks/chunk_by_row.py create mode 100644 cognee/tests/integration/documents/CsvDocument_test.py create mode 100644 cognee/tests/test_data/example_with_header.csv create mode 100644 cognee/tests/unit/processing/chunks/chunk_by_row_test.py diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py index 16eaf0454..b89c1f70e 100644 --- a/cognee/cli/commands/cognify_command.py +++ b/cognee/cli/commands/cognify_command.py @@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin Processing Pipeline: 1. **Document Classification**: Identifies document types and structures -2. **Permission Validation**: Ensures user has processing rights +2. **Permission Validation**: Ensures user has processing rights 3. **Text Chunking**: Breaks content into semantically meaningful segments 4. **Entity Extraction**: Identifies key concepts, people, places, organizations 5. 
**Relationship Detection**: Discovers connections between entities @@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge chunker_class = LangchainChunker except ImportError: fmt.warning("LangchainChunker not available, using TextChunker") + elif args.chunker == "CsvChunker": + try: + from cognee.modules.chunking.CsvChunker import CsvChunker + + chunker_class = CsvChunker + except ImportError: + fmt.warning("CsvChunker not available, using TextChunker") result = await cognee.cognify( datasets=datasets, diff --git a/cognee/cli/config.py b/cognee/cli/config.py index d016608c1..082adbaec 100644 --- a/cognee/cli/config.py +++ b/cognee/cli/config.py @@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [ ] # Chunker choices -CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"] +CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"] # Output format choices OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"] diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index dcdd68cad..10f59a400 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -1,6 +1,8 @@ from typing import BinaryIO import filetype + from .is_text_content import is_text_content +from .is_csv_content import is_csv_content class FileTypeException(Exception): @@ -134,3 +136,44 @@ def guess_file_type(file: BinaryIO) -> filetype.Type: raise FileTypeException(f"Unknown file detected: {file.name}.") return file_type + + +class CsvFileType(filetype.Type): + """ + Match CSV file types based on MIME type and extension. + + Public methods: + - match + + Instance variables: + - MIME: The MIME type of the CSV. + - EXTENSION: The file extension of the CSV. + """ + + MIME = "text/csv" + EXTENSION = "csv" + + def __init__(self): + super().__init__(mime=self.MIME, extension=self.EXTENSION) + + def match(self, buf): + """ + Determine if the given buffer contains csv content. + + Parameters: + ----------- + + - buf: The buffer to check for csv content. + + Returns: + -------- + + Returns True if the buffer is identified as csv content, otherwise False. + """ + + return is_csv_content(buf) + + +csv_file_type = CsvFileType() + +filetype.add_type(csv_file_type) diff --git a/cognee/infrastructure/files/utils/is_csv_content.py b/cognee/infrastructure/files/utils/is_csv_content.py new file mode 100644 index 000000000..07b7ea69b --- /dev/null +++ b/cognee/infrastructure/files/utils/is_csv_content.py @@ -0,0 +1,181 @@ +import csv +from collections import Counter + + +def is_csv_content(content): + """ + Heuristically determine whether a bytes-like object is CSV text. + + Strategy (fail-fast and cheap to expensive): + 1) Decode: Try a small ordered list of common encodings with strict errors. + 2) Line sampling: require >= 2 non-empty lines; sample up to 50 lines. + 3) Delimiter detection: + - Prefer csv.Sniffer() with common delimiters. + - Fallback to a lightweight consistency heuristic. + 4) Lightweight parse check: + - Parse a few lines with the delimiter. + - Ensure at least 2 valid rows and relatively stable column counts. + + Returns: + bool: True if the buffer looks like CSV; False otherwise. + """ + try: + encoding_list = [ + "utf-8", + "utf-8-sig", + "utf-32-le", + "utf-32-be", + "utf-16-le", + "utf-16-be", + "gb18030", + "shift_jis", + "cp949", + "cp1252", + "iso-8859-1", + ] + + # Try to decode strictly—if decoding fails for all encodings, it's not text/CSV. 
+ text = None + for enc in encoding_list: + try: + text = content.decode(enc, errors="strict") + break + except UnicodeDecodeError: + continue + if text is None: + return False + + # Reject empty/whitespace-only payloads. + stripped = text.strip() + if not stripped: + return False + + # Split into logical lines and drop empty ones. Require at least two lines. + lines = [ln for ln in text.splitlines() if ln.strip()] + if len(lines) < 2: + return False + + # Take a small sample to keep sniffing cheap and predictable. + sample_lines = lines[:50] + + # Detect delimiter using csv.Sniffer first; if that fails, use our heuristic. + delimiter = _sniff_delimiter(sample_lines) or _heuristic_delimiter(sample_lines) + if not delimiter: + return False + + # Finally, do a lightweight parse sanity check with the chosen delimiter. + return _lightweight_parse_check(sample_lines, delimiter) + except Exception: + return False + + +def _sniff_delimiter(lines): + """ + Try Python's built-in csv.Sniffer on a sample. + + Args: + lines (list[str]): Sample lines (already decoded). + + Returns: + str | None: The detected delimiter if sniffing succeeds; otherwise None. + """ + # Join up to 50 lines to form the sample string Sniffer will inspect. + sample = "\n".join(lines[:50]) + try: + dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|") + return dialect.delimiter + except Exception: + # Sniffer is known to be brittle on small/dirty samples—silently fallback. + return None + + +def _heuristic_delimiter(lines): + """ + Fallback delimiter detection based on count consistency per line. + + Heuristic: + - For each candidate delimiter, count occurrences per line. + - Keep only lines with count > 0 (line must contain the delimiter). + - Require at least half of lines to contain the delimiter (min 2). + - Compute the mode (most common count). If the proportion of lines that + exhibit the modal count is >= 80%, accept that delimiter. + + Args: + lines (list[str]): Sample lines. + + Returns: + str | None: Best delimiter if one meets the consistency threshold; else None. + """ + candidates = [",", "\t", ";", "|"] + best = None + best_score = 0.0 + + for d in candidates: + # Count how many times the delimiter appears in each line. + counts = [ln.count(d) for ln in lines] + # Consider only lines that actually contain the delimiter at least once. + nonzero = [c for c in counts if c > 0] + + # Require that more than half of lines (and at least 2) contain the delimiter. + if len(nonzero) < max(2, int(0.5 * len(lines))): + continue + + # Find the modal count and its frequency. + cnt = Counter(nonzero) + pairs = cnt.most_common(1) + if not pairs: + continue + + mode, mode_freq = pairs[0] + # Consistency ratio: lines with the modal count / total lines in the sample. + consistency = mode_freq / len(lines) + # Accept if consistent enough and better than any previous candidate. + if mode >= 1 and consistency >= 0.80 and consistency > best_score: + best = d + best_score = consistency + + return best + + +def _lightweight_parse_check(lines, delimiter): + """ + Parse a few lines with csv.reader and check structural stability. + + Heuristic: + - Parse up to 5 lines with the given delimiter. + - Count column widths per parsed row. + - Require at least 2 non-empty rows. + - Allow at most 1 row whose width deviates by >2 columns from the first row. + + Args: + lines (list[str]): Sample lines (decoded). + delimiter (str): Delimiter chosen by sniffing/heuristics. + + Returns: + bool: True if parsing looks stable; False otherwise. 
+ """ + try: + # csv.reader accepts any iterable of strings; feeding the first 10 lines is fine. + reader = csv.reader(lines[:10], delimiter=delimiter) + widths = [] + valid_rows = 0 + for row in reader: + if not row: + continue + + widths.append(len(row)) + valid_rows += 1 + + # Need at least two meaningful rows to make a judgment. + if valid_rows < 2: + return False + + if widths: + first = widths[0] + # Count rows whose width deviates significantly (>2) from the first row. + unstable = sum(1 for w in widths if abs(w - first) > 2) + # Permit at most 1 unstable row among the parsed sample. + return unstable <= 1 + return False + except Exception: + return False diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py index 6b62f7641..37e63c9fc 100644 --- a/cognee/infrastructure/loaders/LoaderEngine.py +++ b/cognee/infrastructure/loaders/LoaderEngine.py @@ -30,6 +30,7 @@ class LoaderEngine: "pypdf_loader", "image_loader", "audio_loader", + "csv_loader", "unstructured_loader", "advanced_pdf_loader", ] diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py index 8a2df80f9..09819fbd2 100644 --- a/cognee/infrastructure/loaders/core/__init__.py +++ b/cognee/infrastructure/loaders/core/__init__.py @@ -3,5 +3,6 @@ from .text_loader import TextLoader from .audio_loader import AudioLoader from .image_loader import ImageLoader +from .csv_loader import CsvLoader -__all__ = ["TextLoader", "AudioLoader", "ImageLoader"] +__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"] diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py new file mode 100644 index 000000000..a314a7a24 --- /dev/null +++ b/cognee/infrastructure/loaders/core/csv_loader.py @@ -0,0 +1,93 @@ +import os +from typing import List +import csv +from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface +from cognee.infrastructure.files.storage import get_file_storage, get_storage_config +from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata + + +class CsvLoader(LoaderInterface): + """ + Core CSV file loader that handles basic CSV file formats. + """ + + @property + def supported_extensions(self) -> List[str]: + """Supported text file extensions.""" + return [ + "csv", + ] + + @property + def supported_mime_types(self) -> List[str]: + """Supported MIME types for text content.""" + return [ + "text/csv", + ] + + @property + def loader_name(self) -> str: + """Unique identifier for this loader.""" + return "csv_loader" + + def can_handle(self, extension: str, mime_type: str) -> bool: + """ + Check if this loader can handle the given file. + + Args: + extension: File extension + mime_type: Optional MIME type + + Returns: + True if file can be handled, False otherwise + """ + if extension in self.supported_extensions and mime_type in self.supported_mime_types: + return True + + return False + + async def load(self, file_path: str, encoding: str = "utf-8", **kwargs): + """ + Load and process the csv file. 
+ + Args: + file_path: Path to the file to load + encoding: Text encoding to use (default: utf-8) + **kwargs: Additional configuration (unused) + + Returns: + LoaderResult containing the file content and metadata + + Raises: + FileNotFoundError: If file doesn't exist + UnicodeDecodeError: If file cannot be decoded with specified encoding + OSError: If file cannot be read + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as f: + file_metadata = await get_file_metadata(f) + # Name ingested file of current loader based on original file content hash + storage_file_name = "text_" + file_metadata["content_hash"] + ".txt" + + row_texts = [] + row_index = 1 + + with open(file_path, "r", encoding=encoding, newline="") as file: + reader = csv.DictReader(file) + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + row_texts.append(f"Row {row_index}:\n{row_text}\n") + row_index += 1 + + content = "\n".join(row_texts) + + storage_config = get_storage_config() + data_root_directory = storage_config["data_root_directory"] + storage = get_file_storage(data_root_directory) + + full_file_path = await storage.store(storage_file_name, content) + + return full_file_path diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py index a6f94be9b..e478edb22 100644 --- a/cognee/infrastructure/loaders/core/text_loader.py +++ b/cognee/infrastructure/loaders/core/text_loader.py @@ -16,7 +16,7 @@ class TextLoader(LoaderInterface): @property def supported_extensions(self) -> List[str]: """Supported text file extensions.""" - return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"] + return ["txt", "md", "json", "xml", "yaml", "yml", "log"] @property def supported_mime_types(self) -> List[str]: @@ -24,7 +24,6 @@ class TextLoader(LoaderInterface): return [ "text/plain", "text/markdown", - "text/csv", "application/json", "text/xml", "application/xml", diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py index d103babe3..b506df5f3 100644 --- a/cognee/infrastructure/loaders/supported_loaders.py +++ b/cognee/infrastructure/loaders/supported_loaders.py @@ -1,5 +1,5 @@ from cognee.infrastructure.loaders.external import PyPdfLoader -from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader +from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader # Registry for loader implementations supported_loaders = { @@ -7,6 +7,7 @@ supported_loaders = { TextLoader.loader_name: TextLoader, ImageLoader.loader_name: ImageLoader, AudioLoader.loader_name: AudioLoader, + CsvLoader.loader_name: CsvLoader, } # Try adding optional loaders diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py new file mode 100644 index 000000000..4ba4a969e --- /dev/null +++ b/cognee/modules/chunking/CsvChunker.py @@ -0,0 +1,35 @@ +from cognee.shared.logging_utils import get_logger + + +from cognee.tasks.chunks import chunk_by_row +from cognee.modules.chunking.Chunker import Chunker +from .models.DocumentChunk import DocumentChunk + +logger = get_logger() + + +class CsvChunker(Chunker): + async def read(self): + async for content_text in self.get_text(): + if content_text is None: + continue + + for chunk_data in chunk_by_row(content_text, self.max_chunk_size): + if chunk_data["chunk_size"] <= 
self.max_chunk_size: + yield DocumentChunk( + id=chunk_data["chunk_id"], + text=chunk_data["text"], + chunk_size=chunk_data["chunk_size"], + is_part_of=self.document, + chunk_index=self.chunk_index, + cut_type=chunk_data["cut_type"], + contains=[], + metadata={ + "index_fields": ["text"], + }, + ) + self.chunk_index += 1 + else: + raise ValueError( + f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}" + ) diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py new file mode 100644 index 000000000..3381275bd --- /dev/null +++ b/cognee/modules/data/processing/document_types/CsvDocument.py @@ -0,0 +1,33 @@ +import io +import csv +from typing import Type + +from cognee.modules.chunking.Chunker import Chunker +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from .Document import Document + + +class CsvDocument(Document): + type: str = "csv" + mime_type: str = "text/csv" + + async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int): + async def get_text(): + async with open_data_file( + self.raw_data_location, mode="r", encoding="utf-8", newline="" + ) as file: + content = file.read() + file_like_obj = io.StringIO(content) + reader = csv.DictReader(file_like_obj) + + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + if not row_text.strip(): + break + yield row_text + + chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text) + + async for chunk in chunker.read(): + yield chunk diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py index 2e862f4ba..133dd53f8 100644 --- a/cognee/modules/data/processing/document_types/__init__.py +++ b/cognee/modules/data/processing/document_types/__init__.py @@ -4,3 +4,4 @@ from .TextDocument import TextDocument from .ImageDocument import ImageDocument from .AudioDocument import AudioDocument from .UnstructuredDocument import UnstructuredDocument +from .CsvDocument import CsvDocument diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py index 22ce96be8..37d4de73e 100644 --- a/cognee/tasks/chunks/__init__.py +++ b/cognee/tasks/chunks/__init__.py @@ -1,4 +1,5 @@ from .chunk_by_word import chunk_by_word from .chunk_by_sentence import chunk_by_sentence from .chunk_by_paragraph import chunk_by_paragraph +from .chunk_by_row import chunk_by_row from .remove_disconnected_chunks import remove_disconnected_chunks diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py new file mode 100644 index 000000000..8daf13689 --- /dev/null +++ b/cognee/tasks/chunks/chunk_by_row.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, Iterator +from uuid import NAMESPACE_OID, uuid5 + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine + + +def _get_pair_size(pair_text: str) -> int: + """ + Calculate the size of a given text in terms of tokens. + + If an embedding engine's tokenizer is available, count the tokens for the provided word. + If the tokenizer is not available, assume the word counts as one token. + + Parameters: + ----------- + + - pair_text (str): The key:value pair text for which the token size is to be calculated. + + Returns: + -------- + + - int: The number of tokens representing the text, typically an integer, depending + on the tokenizer's output. 
+ """ + embedding_engine = get_embedding_engine() + if embedding_engine.tokenizer: + return embedding_engine.tokenizer.count_tokens(pair_text) + else: + return 3 + + +def chunk_by_row( + data: str, + max_chunk_size, +) -> Iterator[Dict[str, Any]]: + """ + Chunk the input text by row while enabling exact text reconstruction. + + This function divides the given text data into smaller chunks on a line-by-line basis, + ensuring that the size of each chunk is less than or equal to the specified maximum + chunk size. It guarantees that when the generated chunks are concatenated, they + reproduce the original text accurately. The tokenization process is handled by + adapters compatible with the vector engine's embedding model. + + Parameters: + ----------- + + - data (str): The input text to be chunked. + - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or + words. + """ + current_chunk_list = [] + chunk_index = 0 + current_chunk_size = 0 + + lines = data.split("\n\n") + for line in lines: + pairs_text = line.split(", ") + + for pair_text in pairs_text: + pair_size = _get_pair_size(pair_text) + if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size): + # Yield current cut chunk + current_chunk = ", ".join(current_chunk_list) + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_cut", + } + + yield chunk_dict + + # Start new chunk with current pair text + current_chunk_list = [] + current_chunk_size = 0 + chunk_index += 1 + + current_chunk_list.append(pair_text) + current_chunk_size += pair_size + + # Yield row chunk + current_chunk = ", ".join(current_chunk_list) + if current_chunk: + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_end", + } + + yield chunk_dict diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py index 9fa512906..e4f13ebd1 100644 --- a/cognee/tasks/documents/classify_documents.py +++ b/cognee/tasks/documents/classify_documents.py @@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import ( ImageDocument, TextDocument, UnstructuredDocument, + CsvDocument, ) from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.utils.generate_node_id import generate_node_id @@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError EXTENSION_TO_DOCUMENT_CLASS = { "pdf": PdfDocument, # Text documents "txt": TextDocument, + "csv": CsvDocument, "docx": UnstructuredDocument, "doc": UnstructuredDocument, "odt": UnstructuredDocument, diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py new file mode 100644 index 000000000..421bb81bd --- /dev/null +++ b/cognee/tests/integration/documents/CsvDocument_test.py @@ -0,0 +1,70 @@ +import os +import sys +import uuid +import pytest +import pathlib +from unittest.mock import patch + +from cognee.modules.chunking.CsvChunker import CsvChunker +from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument +from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine +from cognee.tests.integration.documents.async_gen_zip import async_gen_zip + +chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row") + 
+ +GROUND_TRUTH = { + "chunk_size_10": [ + {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0}, + {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1}, + {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2}, + {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3}, + ], + "chunk_size_128": [ + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0}, + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1}, + ], +} + + +@pytest.mark.parametrize( + "input_file,chunk_size", + [("example_with_header.csv", 10), ("example_with_header.csv", 128)], +) +@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine) +@pytest.mark.asyncio +async def test_CsvDocument(mock_engine, input_file, chunk_size): + # Define file paths of test data + csv_file_path = os.path.join( + pathlib.Path(__file__).parent.parent.parent, + "test_data", + input_file, + ) + + # Define test documents + csv_document = CsvDocument( + id=uuid.uuid4(), + name="example_with_header.csv", + raw_data_location=csv_file_path, + external_metadata="", + mime_type="text/csv", + ) + + # TEST CSV + ground_truth_key = f"chunk_size_{chunk_size}" + async for ground_truth, row_data in async_gen_zip( + GROUND_TRUTH[ground_truth_key], + csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size), + ): + assert ground_truth["token_count"] == row_data.chunk_size, ( + f'{ground_truth["token_count"] = } != {row_data.chunk_size = }' + ) + assert ground_truth["len_text"] == len(row_data.text), ( + f'{ground_truth["len_text"] = } != {len(row_data.text) = }' + ) + assert ground_truth["cut_type"] == row_data.cut_type, ( + f'{ground_truth["cut_type"] = } != {row_data.cut_type = }' + ) + assert ground_truth["chunk_index"] == row_data.chunk_index, ( + f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }' + ) diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv new file mode 100644 index 000000000..dc900e5ef --- /dev/null +++ b/cognee/tests/test_data/example_with_header.csv @@ -0,0 +1,3 @@ +id,name,age,city,country +1,Eric,30,Beijing,China +2,Joe,35,Berlin,Germany diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py new file mode 100644 index 000000000..7d6a73a06 --- /dev/null +++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py @@ -0,0 +1,52 @@ +from itertools import product + +import numpy as np +import pytest + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine +from cognee.tasks.chunks import chunk_by_row + +INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA" +max_chunk_size_vals = [8, 32] + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_isomorphism(input_text, max_chunk_size): + chunks = chunk_by_row(input_text, max_chunk_size) + reconstructed_text = ", ".join([chunk["text"] for chunk in chunks]) + assert reconstructed_text == input_text, ( + f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_row_chunk_length(input_text, max_chunk_size): + chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)) + embedding_engine = 
get_embedding_engine() + + chunk_lengths = np.array( + [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks] + ) + + larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] + assert np.all(chunk_lengths <= max_chunk_size), ( + f"{max_chunk_size = }: {larger_chunks} are too large" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size): + chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size) + chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) + assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( + f"{chunk_indices = } are not monotonically increasing" + )
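
A minimal usage sketch of the new chunk_by_row task, mirroring the unit tests in this patch. It assumes a cognee embedding engine is available (the integration test patches get_embedding_engine with a mock; when no tokenizer is configured, _get_pair_size falls back to a fixed per-pair estimate):

    from cognee.tasks.chunks import chunk_by_row

    # One CSV row rendered the way CsvLoader/CsvDocument emit it:
    # "column: value" pairs joined by ", ".
    row_text = "name: John, age: 30, city: New York, country: USA"

    # chunk_by_row yields dicts with "text", "chunk_size", "chunk_id",
    # "chunk_index", and "cut_type" ("row_cut" when a row is split to respect
    # max_chunk_size, "row_end" for the final chunk of the input).
    chunks = list(chunk_by_row(data=row_text, max_chunk_size=8))

    # Joining chunk texts with ", " reconstructs the original row, which is the
    # isomorphism property asserted in chunk_by_row_test.py.
    assert ", ".join(chunk["text"] for chunk in chunks) == row_text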