feat: csv ingestion loader & chunker
Signed-off-by: EricXiao <taoiaox@gmail.com>
parent 62157a114d
commit 742866b4c9

18 changed files with 623 additions and 6 deletions
@@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasoning

Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
@@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge

        chunker_class = LangchainChunker
    except ImportError:
        fmt.warning("LangchainChunker not available, using TextChunker")
+elif args.chunker == "CsvChunker":
+    try:
+        from cognee.modules.chunking.CsvChunker import CsvChunker
+
+        chunker_class = CsvChunker
+    except ImportError:
+        fmt.warning("CsvChunker not available, using TextChunker")

result = await cognee.cognify(
    datasets=datasets,
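For context, a minimal sketch of how the new chunker choice could be exercised end to end, assuming the CLI maps a `--chunker` flag onto `args.chunker` and that `cognee.cognify` accepts a `chunker` argument (both are inferred from this hunk and `CHUNKER_CHOICES`, not confirmed by it):

# Hypothetical CLI call (flag name assumed from CHUNKER_CHOICES):
#   cognee cognify --chunker CsvChunker csv_demo
#
# Roughly equivalent programmatic use:
import asyncio

import cognee
from cognee.modules.chunking.CsvChunker import CsvChunker


async def main():
    # Paths and dataset names are illustrative.
    await cognee.add("cognee/tests/test_data/example_with_header.csv", dataset_name="csv_demo")
    # The `chunker=` keyword is assumed; the hunk only shows that a chunker_class
    # is resolved from args.chunker before cognify runs.
    await cognee.cognify(datasets=["csv_demo"], chunker=CsvChunker)


asyncio.run(main())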
@@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]

# Chunker choices
-CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
+CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]

# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
@@ -1,6 +1,8 @@
from typing import BinaryIO
import filetype

from .is_text_content import is_text_content
from .is_csv_content import is_csv_content


class FileTypeException(Exception):
@@ -134,3 +136,44 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
        raise FileTypeException(f"Unknown file detected: {file.name}.")

    return file_type


class CsvFileType(filetype.Type):
    """
    Match CSV file types based on MIME type and extension.

    Public methods:
    - match

    Instance variables:
    - MIME: The MIME type of the CSV.
    - EXTENSION: The file extension of the CSV.
    """

    MIME = "text/csv"
    EXTENSION = "csv"

    def __init__(self):
        super().__init__(mime=self.MIME, extension=self.EXTENSION)

    def match(self, buf):
        """
        Determine if the given buffer contains CSV content.

        Parameters:
        -----------
        - buf: The buffer to check for CSV content.

        Returns:
        --------
        Returns True if the buffer is identified as CSV content, otherwise False.
        """

        return is_csv_content(buf)


csv_file_type = CsvFileType()

filetype.add_type(csv_file_type)
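A brief sketch of what this registration buys, assuming the hunk above lives in `guess_file_type.py` and that importing that module is what triggers `filetype.add_type` (module path assumed, sample buffer illustrative):

import filetype

# Importing the module registers CsvFileType with the filetype library as a side effect.
import cognee.infrastructure.files.utils.guess_file_type  # noqa: F401

sample = b"id,name,age\n1,Eric,30\n2,Joe,35\n"

# Plain-text CSV buffers now resolve to text/csv instead of matching no known type.
kind = filetype.guess(sample)
print(kind.mime if kind else None)  # expected: "text/csv"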
181  cognee/infrastructure/files/utils/is_csv_content.py  (new file)

@@ -0,0 +1,181 @@
import csv
from collections import Counter


def is_csv_content(content):
    """
    Heuristically determine whether a bytes-like object is CSV text.

    Strategy (fail-fast and cheap to expensive):
    1) Decode: Try a small ordered list of common encodings with strict errors.
    2) Line sampling: require >= 2 non-empty lines; sample up to 50 lines.
    3) Delimiter detection:
       - Prefer csv.Sniffer() with common delimiters.
       - Fall back to a lightweight consistency heuristic.
    4) Lightweight parse check:
       - Parse a few lines with the delimiter.
       - Ensure at least 2 valid rows and relatively stable column counts.

    Returns:
        bool: True if the buffer looks like CSV; False otherwise.
    """
    try:
        encoding_list = [
            "utf-8",
            "utf-8-sig",
            "utf-32-le",
            "utf-32-be",
            "utf-16-le",
            "utf-16-be",
            "gb18030",
            "shift_jis",
            "cp949",
            "cp1252",
            "iso-8859-1",
        ]

        # Try to decode strictly—if decoding fails for all encodings, it's not text/CSV.
        text = None
        for enc in encoding_list:
            try:
                text = content.decode(enc, errors="strict")
                break
            except UnicodeDecodeError:
                continue
        if text is None:
            return False

        # Reject empty/whitespace-only payloads.
        stripped = text.strip()
        if not stripped:
            return False

        # Split into logical lines and drop empty ones. Require at least two lines.
        lines = [ln for ln in text.splitlines() if ln.strip()]
        if len(lines) < 2:
            return False

        # Take a small sample to keep sniffing cheap and predictable.
        sample_lines = lines[:50]

        # Detect delimiter using csv.Sniffer first; if that fails, use our heuristic.
        delimiter = _sniff_delimiter(sample_lines) or _heuristic_delimiter(sample_lines)
        if not delimiter:
            return False

        # Finally, do a lightweight parse sanity check with the chosen delimiter.
        return _lightweight_parse_check(sample_lines, delimiter)
    except Exception:
        return False


def _sniff_delimiter(lines):
    """
    Try Python's built-in csv.Sniffer on a sample.

    Args:
        lines (list[str]): Sample lines (already decoded).

    Returns:
        str | None: The detected delimiter if sniffing succeeds; otherwise None.
    """
    # Join up to 50 lines to form the sample string Sniffer will inspect.
    sample = "\n".join(lines[:50])
    try:
        dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|")
        return dialect.delimiter
    except Exception:
        # Sniffer is known to be brittle on small/dirty samples—silently fall back.
        return None


def _heuristic_delimiter(lines):
    """
    Fallback delimiter detection based on count consistency per line.

    Heuristic:
    - For each candidate delimiter, count occurrences per line.
    - Keep only lines with count > 0 (line must contain the delimiter).
    - Require at least half of lines to contain the delimiter (min 2).
    - Compute the mode (most common count). If the proportion of lines that
      exhibit the modal count is >= 80%, accept that delimiter.

    Args:
        lines (list[str]): Sample lines.

    Returns:
        str | None: Best delimiter if one meets the consistency threshold; else None.
    """
    candidates = [",", "\t", ";", "|"]
    best = None
    best_score = 0.0

    for d in candidates:
        # Count how many times the delimiter appears in each line.
        counts = [ln.count(d) for ln in lines]
        # Consider only lines that actually contain the delimiter at least once.
        nonzero = [c for c in counts if c > 0]

        # Require that more than half of lines (and at least 2) contain the delimiter.
        if len(nonzero) < max(2, int(0.5 * len(lines))):
            continue

        # Find the modal count and its frequency.
        cnt = Counter(nonzero)
        pairs = cnt.most_common(1)
        if not pairs:
            continue

        mode, mode_freq = pairs[0]
        # Consistency ratio: lines with the modal count / total lines in the sample.
        consistency = mode_freq / len(lines)
        # Accept if consistent enough and better than any previous candidate.
        if mode >= 1 and consistency >= 0.80 and consistency > best_score:
            best = d
            best_score = consistency

    return best


def _lightweight_parse_check(lines, delimiter):
    """
    Parse a few lines with csv.reader and check structural stability.

    Heuristic:
    - Parse up to 10 lines with the given delimiter.
    - Count column widths per parsed row.
    - Require at least 2 non-empty rows.
    - Allow at most 1 row whose width deviates by >2 columns from the first row.

    Args:
        lines (list[str]): Sample lines (decoded).
        delimiter (str): Delimiter chosen by sniffing/heuristics.

    Returns:
        bool: True if parsing looks stable; False otherwise.
    """
    try:
        # csv.reader accepts any iterable of strings; feeding the first 10 lines is fine.
        reader = csv.reader(lines[:10], delimiter=delimiter)
        widths = []
        valid_rows = 0
        for row in reader:
            if not row:
                continue

            widths.append(len(row))
            valid_rows += 1

        # Need at least two meaningful rows to make a judgment.
        if valid_rows < 2:
            return False

        if widths:
            first = widths[0]
            # Count rows whose width deviates significantly (>2) from the first row.
            unstable = sum(1 for w in widths if abs(w - first) > 2)
            # Permit at most 1 unstable row among the parsed sample.
            return unstable <= 1
        return False
    except Exception:
        return False
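A quick usage sketch of the heuristic on its own (inputs are illustrative):

from cognee.infrastructure.files.utils.is_csv_content import is_csv_content

print(is_csv_content(b"id,name\n1,Eric\n2,Joe\n"))         # True: consistent comma-delimited rows
print(is_csv_content(b"just one plain sentence of text"))  # False: fewer than two non-empty lines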
@@ -30,6 +30,7 @@ class LoaderEngine:
        "pypdf_loader",
        "image_loader",
        "audio_loader",
+       "csv_loader",
        "unstructured_loader",
        "advanced_pdf_loader",
    ]
@@ -3,5 +3,6 @@
from .text_loader import TextLoader
from .audio_loader import AudioLoader
from .image_loader import ImageLoader
+from .csv_loader import CsvLoader

-__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
+__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
93  cognee/infrastructure/loaders/core/csv_loader.py  (new file)

@@ -0,0 +1,93 @@
import os
from typing import List
import csv
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata


class CsvLoader(LoaderInterface):
    """
    Core CSV file loader that handles basic CSV file formats.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported CSV file extensions."""
        return [
            "csv",
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for CSV content."""
        return [
            "text/csv",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "csv_loader"

    def can_handle(self, extension: str, mime_type: str) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            extension: File extension
            mime_type: Optional MIME type

        Returns:
            True if file can be handled, False otherwise
        """
        if extension in self.supported_extensions and mime_type in self.supported_mime_types:
            return True

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
        """
        Load and process the CSV file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            Full path of the stored text file generated from the CSV content

        Raises:
            FileNotFoundError: If file doesn't exist
            UnicodeDecodeError: If file cannot be decoded with specified encoding
            OSError: If file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, "rb") as f:
            file_metadata = await get_file_metadata(f)
            # Name ingested file of current loader based on original file content hash
            storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

        row_texts = []
        row_index = 1

        with open(file_path, "r", encoding=encoding, newline="") as file:
            reader = csv.DictReader(file)
            for row in reader:
                pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
                row_text = ", ".join(pairs)
                row_texts.append(f"Row {row_index}:\n{row_text}\n")
                row_index += 1

        content = "\n".join(row_texts)

        storage_config = get_storage_config()
        data_root_directory = storage_config["data_root_directory"]
        storage = get_file_storage(data_root_directory)

        full_file_path = await storage.store(storage_file_name, content)

        return full_file_path
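To show the text rendering this loader stores, independent of cognee's storage layer, here is a minimal stdlib-only sketch of the same row-to-text transformation:

import csv
import io

# Same shape as cognee/tests/test_data/example_with_header.csv (added later in this commit).
csv_text = "id,name,age,city,country\n1,Eric,30,Beijing,China\n2,Joe,35,Berlin,Germany\n"

row_texts = []
reader = csv.DictReader(io.StringIO(csv_text))
for row_index, row in enumerate(reader, start=1):
    # Mirror CsvLoader.load: render each row as "header: value" pairs.
    row_text = ", ".join(f"{k}: {v}" for k, v in row.items())
    row_texts.append(f"Row {row_index}:\n{row_text}\n")

print("\n".join(row_texts))
# Row 1:
# id: 1, name: Eric, age: 30, city: Beijing, country: China
#
# Row 2:
# id: 2, name: Joe, age: 35, city: Berlin, country: Germany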
@@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
    @property
    def supported_extensions(self) -> List[str]:
        """Supported text file extensions."""
-       return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
+       return ["txt", "md", "json", "xml", "yaml", "yml", "log"]

    @property
    def supported_mime_types(self) -> List[str]:
@@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
        return [
            "text/plain",
            "text/markdown",
-           "text/csv",
            "application/json",
            "text/xml",
            "application/xml",
@@ -1,5 +1,5 @@
from cognee.infrastructure.loaders.external import PyPdfLoader
-from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
+from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader

# Registry for loader implementations
supported_loaders = {
@@ -7,6 +7,7 @@ supported_loaders = {
    TextLoader.loader_name: TextLoader,
    ImageLoader.loader_name: ImageLoader,
    AudioLoader.loader_name: AudioLoader,
+   CsvLoader.loader_name: CsvLoader,
}

# Try adding optional loaders
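A quick sketch of what the new registry entry enables, assuming the loader engine instantiates entries from `supported_loaders` and probes them with `can_handle` (the probing logic is not part of this diff):

from cognee.infrastructure.loaders.core import CsvLoader

loader = CsvLoader()
# can_handle requires both the extension and the MIME type to match.
print(loader.can_handle("csv", "text/csv"))    # True
print(loader.can_handle("csv", "text/plain"))  # False: MIME type is not text/csv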
35  cognee/modules/chunking/CsvChunker.py  (new file)

@@ -0,0 +1,35 @@
from cognee.shared.logging_utils import get_logger


from cognee.tasks.chunks import chunk_by_row
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk

logger = get_logger()


class CsvChunker(Chunker):
    async def read(self):
        async for content_text in self.get_text():
            if content_text is None:
                continue

            for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
                if chunk_data["chunk_size"] <= self.max_chunk_size:
                    yield DocumentChunk(
                        id=chunk_data["chunk_id"],
                        text=chunk_data["text"],
                        chunk_size=chunk_data["chunk_size"],
                        is_part_of=self.document,
                        chunk_index=self.chunk_index,
                        cut_type=chunk_data["cut_type"],
                        contains=[],
                        metadata={
                            "index_fields": ["text"],
                        },
                    )
                    self.chunk_index += 1
                else:
                    raise ValueError(
                        f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
                    )
33  cognee/modules/data/processing/document_types/CsvDocument.py  (new file)

@@ -0,0 +1,33 @@
import io
import csv
from typing import Type

from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document


class CsvDocument(Document):
    type: str = "csv"
    mime_type: str = "text/csv"

    async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
        async def get_text():
            async with open_data_file(
                self.raw_data_location, mode="r", encoding="utf-8", newline=""
            ) as file:
                content = file.read()
                file_like_obj = io.StringIO(content)
                reader = csv.DictReader(file_like_obj)

                for row in reader:
                    pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
                    row_text = ", ".join(pairs)
                    if not row_text.strip():
                        break
                    yield row_text

        chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)

        async for chunk in chunker.read():
            yield chunk
@@ -4,3 +4,4 @@ from .TextDocument import TextDocument
from .ImageDocument import ImageDocument
from .AudioDocument import AudioDocument
from .UnstructuredDocument import UnstructuredDocument
+from .CsvDocument import CsvDocument
@@ -1,4 +1,5 @@
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph
+from .chunk_by_row import chunk_by_row
from .remove_disconnected_chunks import remove_disconnected_chunks
94  cognee/tasks/chunks/chunk_by_row.py  (new file)

@@ -0,0 +1,94 @@
from typing import Any, Dict, Iterator
from uuid import NAMESPACE_OID, uuid5

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine


def _get_pair_size(pair_text: str) -> int:
    """
    Calculate the size of a given text in terms of tokens.

    If an embedding engine's tokenizer is available, count the tokens for the provided
    pair text. If the tokenizer is not available, assume the pair counts as three tokens.

    Parameters:
    -----------
    - pair_text (str): The key:value pair text for which the token size is to be calculated.

    Returns:
    --------
    - int: The number of tokens representing the text, depending on the tokenizer's output.
    """
    embedding_engine = get_embedding_engine()
    if embedding_engine.tokenizer:
        return embedding_engine.tokenizer.count_tokens(pair_text)
    else:
        return 3


def chunk_by_row(
    data: str,
    max_chunk_size,
) -> Iterator[Dict[str, Any]]:
    """
    Chunk the input text by row while enabling exact text reconstruction.

    This function divides the given text data into smaller chunks row by row, splitting
    on key: value pairs, and ensures that the size of each chunk is less than or equal
    to the specified maximum chunk size. It guarantees that when the generated chunks
    are concatenated, they reproduce the original text accurately. The tokenization
    process is handled by adapters compatible with the vector engine's embedding model.

    Parameters:
    -----------
    - data (str): The input text to be chunked.
    - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or words.
    """
    current_chunk_list = []
    chunk_index = 0
    current_chunk_size = 0

    lines = data.split("\n\n")
    for line in lines:
        pairs_text = line.split(", ")

        for pair_text in pairs_text:
            pair_size = _get_pair_size(pair_text)
            if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
                # Yield current cut chunk
                current_chunk = ", ".join(current_chunk_list)
                chunk_dict = {
                    "text": current_chunk,
                    "chunk_size": current_chunk_size,
                    "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                    "chunk_index": chunk_index,
                    "cut_type": "row_cut",
                }

                yield chunk_dict

                # Start new chunk with current pair text
                current_chunk_list = []
                current_chunk_size = 0
                chunk_index += 1

            current_chunk_list.append(pair_text)
            current_chunk_size += pair_size

    # Yield the final row chunk
    current_chunk = ", ".join(current_chunk_list)
    if current_chunk:
        chunk_dict = {
            "text": current_chunk,
            "chunk_size": current_chunk_size,
            "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
            "chunk_index": chunk_index,
            "cut_type": "row_end",
        }

        yield chunk_dict
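A small worked example of the chunking behaviour, stubbing `get_embedding_engine` so that the fallback in `_get_pair_size` counts every pair as 3 tokens (the stub mirrors the trick used by the integration test below; the sample text and sizes are illustrative):

import sys
from types import SimpleNamespace
from unittest.mock import patch

from cognee.tasks.chunks import chunk_by_row

# The package re-exports the function under the submodule's name, so fetch the
# actual module from sys.modules before patching (same approach as the tests).
chunk_by_row_module = sys.modules["cognee.tasks.chunks.chunk_by_row"]

fake_engine = SimpleNamespace(tokenizer=None)  # tokenizer=None -> 3 tokens per pair
row_text = "name: John, age: 30, city: New York, country: USA"

with patch.object(chunk_by_row_module, "get_embedding_engine", return_value=fake_engine):
    for chunk in chunk_by_row(row_text, max_chunk_size=8):
        print(chunk["chunk_index"], chunk["cut_type"], chunk["chunk_size"], repr(chunk["text"]))

# Expected: two pairs (6 "tokens") fit under max_chunk_size=8, so the row splits in two:
# 0 row_cut 6 'name: John, age: 30'
# 1 row_end 6 'city: New York, country: USA'

Note that a single key: value pair larger than max_chunk_size is still emitted as an oversized chunk here; CsvChunker above then rejects it with a ValueError rather than splitting inside a pair.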
@@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (
    ImageDocument,
    TextDocument,
    UnstructuredDocument,
+   CsvDocument,
)
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id

@@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError
EXTENSION_TO_DOCUMENT_CLASS = {
    "pdf": PdfDocument,  # Text documents
    "txt": TextDocument,
+   "csv": CsvDocument,
    "docx": UnstructuredDocument,
    "doc": UnstructuredDocument,
    "odt": UnstructuredDocument,
70  cognee/tests/integration/documents/CsvDocument_test.py  (new file)

@@ -0,0 +1,70 @@
import os
import sys
import uuid
import pytest
import pathlib
from unittest.mock import patch

from cognee.modules.chunking.CsvChunker import CsvChunker
from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip

chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")


GROUND_TRUTH = {
    "chunk_size_10": [
        {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
        {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
        {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
        {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
    ],
    "chunk_size_128": [
        {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
        {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
    ],
}


@pytest.mark.parametrize(
    "input_file,chunk_size",
    [("example_with_header.csv", 10), ("example_with_header.csv", 128)],
)
@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
@pytest.mark.asyncio
async def test_CsvDocument(mock_engine, input_file, chunk_size):
    # Define file paths of test data
    csv_file_path = os.path.join(
        pathlib.Path(__file__).parent.parent.parent,
        "test_data",
        input_file,
    )

    # Define test documents
    csv_document = CsvDocument(
        id=uuid.uuid4(),
        name="example_with_header.csv",
        raw_data_location=csv_file_path,
        external_metadata="",
        mime_type="text/csv",
    )

    # TEST CSV
    ground_truth_key = f"chunk_size_{chunk_size}"
    async for ground_truth, row_data in async_gen_zip(
        GROUND_TRUTH[ground_truth_key],
        csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
    ):
        assert ground_truth["token_count"] == row_data.chunk_size, (
            f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
        )
        assert ground_truth["len_text"] == len(row_data.text), (
            f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
        )
        assert ground_truth["cut_type"] == row_data.cut_type, (
            f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
        )
        assert ground_truth["chunk_index"] == row_data.chunk_index, (
            f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
        )
3  cognee/tests/test_data/example_with_header.csv  (new file)

@@ -0,0 +1,3 @@
id,name,age,city,country
1,Eric,30,Beijing,China
2,Joe,35,Berlin,Germany
52  cognee/tests/unit/processing/chunks/chunk_by_row_test.py  (new file)

@@ -0,0 +1,52 @@
from itertools import product

import numpy as np
import pytest

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.tasks.chunks import chunk_by_row

INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
max_chunk_size_vals = [8, 32]


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
    chunks = chunk_by_row(input_text, max_chunk_size)
    reconstructed_text = ", ".join([chunk["text"] for chunk in chunks])
    assert reconstructed_text == input_text, (
        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_row_chunk_length(input_text, max_chunk_size):
    chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))
    embedding_engine = get_embedding_engine()

    chunk_lengths = np.array(
        [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
    )

    larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
    assert np.all(chunk_lengths <= max_chunk_size), (
        f"{max_chunk_size = }: {larger_chunks} are too large"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
    chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)
    chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
    assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
        f"{chunk_indices = } are not monotonically increasing"
    )