feat: outsources chunking parameters to extract chunk from documents … (#289)

* feat: outsource chunking parameters to the extract_chunks_from_documents task
This commit is contained in:
hajdul88 2024-12-17 11:31:31 +01:00 committed by GitHub
parent bfa0f06fb4
commit 9e7ab6492a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 40 additions and 20 deletions

View file

@ -1,6 +1,6 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from .ChunkerMapping import ChunkerConfig
class AudioDocument(Document):
type: str = "audio"
@ -9,11 +9,12 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location)
return(result.text)
def read(self, chunk_size: int, chunker: str):
    """Transcribe this audio document and yield its chunks.

    Args:
        chunk_size: maximum size of each chunk produced by the chunker.
        chunker: registry name of the chunker implementation (e.g.
            "text_chunker"), resolved via ChunkerConfig.get_chunker.

    Yields:
        Chunks produced by the selected chunker over the transcript.

    Raises:
        NotImplementedError: if `chunker` names an unregistered chunker.
    """
    # Transcribe the audio file once; the chunker pulls the text lazily.
    text = self.create_transcript()
    chunker_class = ChunkerConfig.get_chunker(chunker)
    # Use a distinct local name so the `chunker` parameter (a registry key)
    # is not shadowed by the chunker *instance*, as it was originally.
    chunker_instance = chunker_class(self, chunk_size=chunk_size, get_text=lambda: [text])
    yield from chunker_instance.read()

View file

@ -0,0 +1,15 @@
from cognee.modules.chunking.TextChunker import TextChunker
class ChunkerConfig:
    """Registry of available chunker implementations, keyed by name."""

    # Map from a chunker's registry name to its implementing class.
    chunker_mapping = {
        "text_chunker": TextChunker,
    }

    @classmethod
    def get_chunker(cls, chunker_name: str):
        """Return the chunker class registered under *chunker_name*.

        Raises:
            NotImplementedError: if no chunker is registered under that name.
        """
        try:
            return cls.chunker_mapping[chunker_name]
        except KeyError:
            # `from None` keeps the traceback clean, matching the original's
            # behavior of raising a fresh NotImplementedError.
            raise NotImplementedError(
                f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
            ) from None

View file

@ -13,5 +13,5 @@ class Document(DataPoint):
"type": "Document"
}
def read(self, chunk_size: int, chunker: str = "text_chunker") -> str:
    """Read the document as chunks (abstract; overridden by subclasses).

    Args:
        chunk_size: maximum size of each produced chunk.
        chunker: registry name of the chunker implementation to use.
    """
    # Fix: the changed signature wrote `chunker = str`, which silently made
    # the built-in `str` *type* the default value — clearly a typo for a
    # `chunker: str` annotation. A string default keeps calls that omit
    # `chunker` working and matches extract_chunks_from_documents' default.
    pass

View file

@ -1,6 +1,6 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from .ChunkerMapping import ChunkerConfig
class ImageDocument(Document):
type: str = "image"
@ -10,10 +10,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location)
return(result.choices[0].message.content)
def read(self, chunk_size: int, chunker: str):
    """Transcribe this image document and yield its chunks.

    Args:
        chunk_size: maximum size of each chunk produced by the chunker.
        chunker: registry name of the chunker implementation (e.g.
            "text_chunker"), resolved via ChunkerConfig.get_chunker.

    Yields:
        Chunks produced by the selected chunker over the transcription.

    Raises:
        NotImplementedError: if `chunker` names an unregistered chunker.
    """
    # Transcribe the image file once; the chunker pulls the text lazily.
    text = self.transcribe_image()
    chunker_class = ChunkerConfig.get_chunker(chunker)
    # Use a distinct local name so the `chunker` parameter (a registry key)
    # is not shadowed by the chunker *instance*, as it was originally.
    chunker_instance = chunker_class(self, chunk_size=chunk_size, get_text=lambda: [text])
    yield from chunker_instance.read()

View file

@ -1,11 +1,11 @@
from pypdf import PdfReader
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from .ChunkerMapping import ChunkerConfig
class PdfDocument(Document):
type: str = "pdf"
def read(self, chunk_size: int):
def read(self, chunk_size: int, chunker: str):
file = PdfReader(self.raw_data_location)
def get_text():
@ -13,7 +13,8 @@ class PdfDocument(Document):
page_text = page.extract_text()
yield page_text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read()

View file

@ -1,10 +1,10 @@
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from .ChunkerMapping import ChunkerConfig
class TextDocument(Document):
type: str = "text"
def read(self, chunk_size: int):
def read(self, chunk_size: int, chunker: str):
def get_text():
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
while True:
@ -15,6 +15,8 @@ class TextDocument(Document):
yield text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read()

View file

@ -1,7 +1,7 @@
from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker: str = "text_chunker"):
    """Asynchronously yield chunks produced by reading each document in order.

    Args:
        documents: documents to chunk, processed sequentially.
        chunk_size: maximum chunk size forwarded to each document's reader.
        chunker: registry name of the chunker implementation, forwarded to
            Document.read (annotated as `str`; the original left it untyped).

    Yields:
        Document chunks, in document order, as produced by each
        document's `read` generator.
    """
    for document in documents:
        # `read` is a generator; re-yield its chunks one at a time so
        # consumers can process them as they are produced.
        for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
            yield document_chunk

View file

@ -31,7 +31,7 @@ def test_AudioDocument():
)
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64)
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
):
assert (
ground_truth["word_count"] == paragraph_data.word_count

View file

@ -21,7 +21,7 @@ def test_ImageDocument():
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64)
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
):
assert (
ground_truth["word_count"] == paragraph_data.word_count

View file

@ -22,7 +22,7 @@ def test_PdfDocument():
)
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024)
GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
):
assert (
ground_truth["word_count"] == paragraph_data.word_count

View file

@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
)
for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size)
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
):
assert (
ground_truth["word_count"] == paragraph_data.word_count