feat: outsource chunking parameters to the extract-chunks-from-documents task … (#289)

* feat: outsource chunking parameters to the extract_chunks_from_documents task
This commit is contained in:
hajdul88 2024-12-17 11:31:31 +01:00 committed by GitHub
parent bfa0f06fb4
commit 9e7ab6492a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 40 additions and 20 deletions

View file

@ -1,6 +1,6 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class AudioDocument(Document): class AudioDocument(Document):
type: str = "audio" type: str = "audio"
@ -9,11 +9,12 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
return(result.text) return(result.text)
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
# Transcribe the audio file # Transcribe the audio file
text = self.create_transcript() text = self.create_transcript()
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
yield from chunker.read() yield from chunker.read()

View file

@ -0,0 +1,15 @@
from cognee.modules.chunking.TextChunker import TextChunker
class ChunkerConfig:
    """Registry resolving chunker names to their implementation classes.

    New chunker implementations are registered by adding an entry to
    ``chunker_mapping``; callers look them up by name via ``get_chunker``.
    """

    # Name -> chunker class; "text_chunker" is currently the only implementation.
    chunker_mapping = {
        "text_chunker": TextChunker,
    }

    @classmethod
    def get_chunker(cls, chunker_name: str):
        """Return the chunker class registered under ``chunker_name``.

        Raises:
            NotImplementedError: if no chunker is registered under that name.
        """
        if chunker_name not in cls.chunker_mapping:
            raise NotImplementedError(
                f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
            )
        return cls.chunker_mapping[chunker_name]

View file

@ -13,5 +13,5 @@ class Document(DataPoint):
"type": "Document" "type": "Document"
} }
def read(self, chunk_size: int) -> str: def read(self, chunk_size: int, chunker = str) -> str:
pass pass

View file

@ -1,6 +1,6 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class ImageDocument(Document): class ImageDocument(Document):
type: str = "image" type: str = "image"
@ -10,10 +10,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return(result.choices[0].message.content) return(result.choices[0].message.content)
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
yield from chunker.read() yield from chunker.read()

View file

@ -1,11 +1,11 @@
from pypdf import PdfReader from pypdf import PdfReader
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -13,7 +13,8 @@ class PdfDocument(Document):
page_text = page.extract_text() page_text = page.extract_text()
yield page_text yield page_text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read() yield from chunker.read()

View file

@ -1,10 +1,10 @@
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
def get_text(): def get_text():
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file: with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
while True: while True:
@ -15,6 +15,8 @@ class TextDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read() yield from chunker.read()

View file

@ -1,7 +1,7 @@
from cognee.modules.data.processing.document_types.Document import Document from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024): async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'):
for document in documents: for document in documents:
for document_chunk in document.read(chunk_size = chunk_size): for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
yield document_chunk yield document_chunk

View file

@ -31,7 +31,7 @@ def test_AudioDocument():
) )
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64) GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -21,7 +21,7 @@ def test_ImageDocument():
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64) GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -22,7 +22,7 @@ def test_PdfDocument():
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024) GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size) GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count