diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py
index 2faa9903e..50a3e081d 100644
--- a/cognee/api/v1/cognify/cognify_v2.py
+++ b/cognee/api/v1/cognify/cognify_v2.py
@@ -25,6 +25,7 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.modules.chunking.TextChunker import TextChunker

 logger = logging.getLogger("cognify.v2")

@@ -123,7 +124,9 @@ async def get_default_tasks(
         Task(classify_documents),
         Task(check_permissions_on_documents, user=user, permissions=["write"]),
         Task(
-            extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
+            extract_chunks_from_documents,
+            max_chunk_tokens=get_max_chunk_tokens(),
+            chunker=TextChunker,
         ),  # Extract text chunks based on the document type.
         Task(
             extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py
index 75152fd3d..d6c7faeb5 100644
--- a/cognee/modules/data/processing/document_types/AudioDocument.py
+++ b/cognee/modules/data/processing/document_types/AudioDocument.py
@@ -1,8 +1,5 @@
-from typing import Optional
-
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .ChunkerMapping import ChunkerConfig
 from .Document import Document

@@ -13,13 +10,12 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         # Transcribe the audio file
         text = self.create_transcript()

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
         )
diff --git a/cognee/modules/data/processing/document_types/ChunkerMapping.py b/cognee/modules/data/processing/document_types/ChunkerMapping.py
deleted file mode 100644
index f9a251528..000000000
--- a/cognee/modules/data/processing/document_types/ChunkerMapping.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from cognee.modules.chunking.TextChunker import TextChunker
-
-
-class ChunkerConfig:
-    chunker_mapping = {"text_chunker": TextChunker}
-
-    @classmethod
-    def get_chunker(cls, chunker_name: str):
-        chunker_class = cls.chunker_mapping.get(chunker_name)
-        if chunker_class is None:
-            raise NotImplementedError(
-                f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
-            )
-        return chunker_class
diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py
index b378abd44..08ee11fbf 100644
--- a/cognee/modules/data/processing/document_types/Document.py
+++ b/cognee/modules/data/processing/document_types/Document.py
@@ -9,5 +9,7 @@ class Document(DataPoint):
     mime_type: str
     metadata: dict = {"index_fields": ["name"]}

-    def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
+    def read(
+        self, chunk_size: int, chunker_cls: type, max_chunk_tokens: Optional[int] = None
+    ) -> str:
         pass
diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py
index 5f4cb287c..d628ba2cd 100644
--- a/cognee/modules/data/processing/document_types/ImageDocument.py
+++ b/cognee/modules/data/processing/document_types/ImageDocument.py
@@ -1,8 +1,5 @@
-from typing import Optional
-
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .ChunkerMapping import ChunkerConfig
 from .Document import Document

@@ -13,12 +10,11 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         # Transcribe the image file
         text = self.transcribe_image()

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
         )
diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py
index 8273e0177..28d40d9c7 100644
--- a/cognee/modules/data/processing/document_types/PdfDocument.py
+++ b/cognee/modules/data/processing/document_types/PdfDocument.py
@@ -1,15 +1,12 @@
-from typing import Optional
-
 from pypdf import PdfReader

-from .ChunkerMapping import ChunkerConfig
 from .Document import Document

 class PdfDocument(Document):
     type: str = "pdf"

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         file = PdfReader(self.raw_data_location)

         def get_text():
@@ -17,8 +14,7 @@ class PdfDocument(Document):
             page_text = page.extract_text()
             yield page_text

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )
diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py
index 2ab3d9185..140811ca3 100644
--- a/cognee/modules/data/processing/document_types/TextDocument.py
+++ b/cognee/modules/data/processing/document_types/TextDocument.py
@@ -1,11 +1,10 @@
-from .ChunkerMapping import ChunkerConfig
 from .Document import Document

 class TextDocument(Document):
     type: str = "text"

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
@@ -16,9 +15,7 @@ class TextDocument(Document):
                     yield text

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )
diff --git a/cognee/modules/data/processing/document_types/UnstructuredDocument.py b/cognee/modules/data/processing/document_types/UnstructuredDocument.py
index 254958d14..0cac46f8a 100644
--- a/cognee/modules/data/processing/document_types/UnstructuredDocument.py
+++ b/cognee/modules/data/processing/document_types/UnstructuredDocument.py
@@ -1,5 +1,4 @@
 from io import StringIO
-from typing import Optional

 from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.exceptions import UnstructuredLibraryImportError
@@ -10,7 +9,7 @@ from .Document import Document
 class UnstructuredDocument(Document):
     type: str = "unstructured"

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -29,7 +28,7 @@ class UnstructuredDocument(Document):

             yield text

-        chunker = TextChunker(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )
diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py
index a65f32fc9..68d1a0887 100644
--- a/cognee/tasks/documents/extract_chunks_from_documents.py
+++ b/cognee/tasks/documents/extract_chunks_from_documents.py
@@ -5,6 +5,7 @@ from sqlalchemy import select
 from cognee.modules.data.models import Data
 from cognee.infrastructure.databases.relational import get_relational_engine
 from uuid import UUID
+from cognee.modules.chunking.TextChunker import TextChunker

 async def update_document_token_count(document_id: UUID, token_count: int) -> None:
@@ -26,7 +27,7 @@ async def extract_chunks_from_documents(
     documents: list[Document],
     max_chunk_tokens: int,
     chunk_size: int = 1024,
-    chunker="text_chunker",
+    chunker=TextChunker,
 ) -> AsyncGenerator:
     """
     Extracts chunks of data from a list of documents based on the specified chunking parameters.
@@ -38,7 +39,7 @@ async def extract_chunks_from_documents(
     for document in documents:
         document_token_count = 0
         for document_chunk in document.read(
-            chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
+            chunk_size=chunk_size, chunker_cls=chunker, max_chunk_tokens=max_chunk_tokens
         ):
             document_token_count += document_chunk.token_count
             yield document_chunk
diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py
index 38b547140..3d9ba532e 100644
--- a/cognee/tests/integration/documents/AudioDocument_test.py
+++ b/cognee/tests/integration/documents/AudioDocument_test.py
@@ -1,6 +1,6 @@
 import uuid
 from unittest.mock import patch
-
+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument

 GROUND_TRUTH = [
@@ -34,7 +34,8 @@ def test_AudioDocument():
     )
     with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
+            GROUND_TRUTH,
+            document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
         ):
             assert ground_truth["word_count"] == paragraph_data.word_count, (
                 f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py
index faa54fa27..dad9e8ab9 100644
--- a/cognee/tests/integration/documents/ImageDocument_test.py
+++ b/cognee/tests/integration/documents/ImageDocument_test.py
@@ -1,6 +1,6 @@
 import uuid
 from unittest.mock import patch
-
+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument

 GROUND_TRUTH = [
@@ -23,7 +23,8 @@ def test_ImageDocument():
     )
     with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
+            GROUND_TRUTH,
+            document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
         ):
             assert ground_truth["word_count"] == paragraph_data.word_count, (
                 f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py
index e9530fc12..6abe2a858 100644
--- a/cognee/tests/integration/documents/PdfDocument_test.py
+++ b/cognee/tests/integration/documents/PdfDocument_test.py
@@ -1,6 +1,6 @@
 import os
 import uuid
-
+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument

 GROUND_TRUTH = [
@@ -25,7 +25,7 @@ def test_PdfDocument():
     )

     for ground_truth, paragraph_data in zip(
-        GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048)
+        GROUND_TRUTH, document.read(chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=2048)
     ):
         assert ground_truth["word_count"] == paragraph_data.word_count, (
             f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py
index 99e28a3ac..7c1af10b5 100644
--- a/cognee/tests/integration/documents/TextDocument_test.py
+++ b/cognee/tests/integration/documents/TextDocument_test.py
@@ -2,7 +2,7 @@ import os
 import uuid

 import pytest
-
+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.TextDocument import TextDocument

 GROUND_TRUTH = {
@@ -38,7 +38,7 @@ def test_TextDocument(input_file, chunk_size):

     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH[input_file],
-        document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024),
+        document.read(chunk_size=chunk_size, chunker_cls=TextChunker, max_chunk_tokens=1024),
     ):
         assert ground_truth["word_count"] == paragraph_data.word_count, (
             f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
diff --git a/cognee/tests/integration/documents/UnstructuredDocument_test.py b/cognee/tests/integration/documents/UnstructuredDocument_test.py
index d76843c0a..1331ee8ef 100644
--- a/cognee/tests/integration/documents/UnstructuredDocument_test.py
+++ b/cognee/tests/integration/documents/UnstructuredDocument_test.py
@@ -1,6 +1,6 @@
 import os
 import uuid
-
+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument

@@ -69,7 +69,7 @@ def test_UnstructuredDocument():

     # Test PPTX
     for paragraph_data in pptx_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
@@ -79,7 +79,7 @@ def test_UnstructuredDocument():

     # Test DOCX
     for paragraph_data in docx_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
@@ -89,7 +89,7 @@ def test_UnstructuredDocument():

     # TEST CSV
     for paragraph_data in csv_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
         assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
@@ -101,7 +101,7 @@ def test_UnstructuredDocument():

     # Test XLSX
     for paragraph_data in xlsx_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
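
For reviewers, a sketch of what this refactor enables: `read()` and `extract_chunks_from_documents` now accept a chunker class directly instead of a `ChunkerConfig` registry key, so any class matching `TextChunker`'s constructor can be dropped in without editing a mapping. The `WholeTextChunker`, `ToyChunk`, and `chunk_all` names below are hypothetical, and the chunker's `read()` entry point is an assumption by analogy with `TextChunker` (the hunks above end before the chunker is consumed); this is illustrative only, not part of the diff, and running it assumes a configured cognee environment.

```python
from dataclasses import dataclass

from cognee.tasks.documents import extract_chunks_from_documents


@dataclass
class ToyChunk:
    # Minimal chunk record: the callers above read .text, .word_count and
    # .token_count off every yielded chunk, so a custom chunker must supply them.
    text: str
    word_count: int
    token_count: int


class WholeTextChunker:
    """Hypothetical chunker that emits each get_text() item as one chunk."""

    def __init__(self, document, chunk_size, get_text, max_chunk_tokens):
        # Signature mirrors how the read() methods above instantiate chunker_cls.
        self.document = document
        self.chunk_size = chunk_size
        self.get_text = get_text
        self.max_chunk_tokens = max_chunk_tokens

    def read(self):
        # Assumed entry point, by analogy with TextChunker.
        for text in self.get_text():
            words = len(text.split())
            # token_count is approximated by word count purely for illustration.
            yield ToyChunk(text=text, word_count=words, token_count=words)


async def chunk_all(documents):
    # The chunker parameter now takes the class itself; TextChunker stays the default.
    async for chunk in extract_chunks_from_documents(
        documents, max_chunk_tokens=1024, chunker=WholeTextChunker
    ):
        print(chunk.word_count, chunk.token_count)
```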