feat: outsources chunking parameters to extract chunk from documents … (#289)
* feat: outsources chunking parameters to extract chunk from documents task
This commit is contained in:
parent
bfa0f06fb4
commit
9e7ab6492a
11 changed files with 40 additions and 20 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class AudioDocument(Document):
|
||||
type: str = "audio"
|
||||
|
|
@ -9,11 +9,12 @@ class AudioDocument(Document):
|
|||
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||
return(result.text)
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
# Transcribe the audio file
|
||||
|
||||
text = self.create_transcript()
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
|
||||
class ChunkerConfig:
|
||||
chunker_mapping = {
|
||||
"text_chunker": TextChunker
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_chunker(cls, chunker_name: str):
|
||||
chunker_class = cls.chunker_mapping.get(chunker_name)
|
||||
if chunker_class is None:
|
||||
raise NotImplementedError(
|
||||
f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
|
||||
)
|
||||
return chunker_class
|
||||
|
|
@ -13,5 +13,5 @@ class Document(DataPoint):
|
|||
"type": "Document"
|
||||
}
|
||||
|
||||
def read(self, chunk_size: int) -> str:
|
||||
pass
|
||||
def read(self, chunk_size: int, chunker = str) -> str:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class ImageDocument(Document):
|
||||
type: str = "image"
|
||||
|
|
@ -10,10 +10,11 @@ class ImageDocument(Document):
|
|||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||
return(result.choices[0].message.content)
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
# Transcribe the image file
|
||||
text = self.transcribe_image()
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from pypdf import PdfReader
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
file = PdfReader(self.raw_data_location)
|
||||
|
||||
def get_text():
|
||||
|
|
@ -13,7 +13,8 @@ class PdfDocument(Document):
|
|||
page_text = page.extract_text()
|
||||
yield page_text
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
def get_text():
|
||||
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
|
||||
while True:
|
||||
|
|
@ -15,6 +15,8 @@ class TextDocument(Document):
|
|||
|
||||
yield text
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
|
||||
|
||||
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
|
||||
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'):
|
||||
for document in documents:
|
||||
for document_chunk in document.read(chunk_size = chunk_size):
|
||||
for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
|
||||
yield document_chunk
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def test_AudioDocument():
|
|||
)
|
||||
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=64)
|
||||
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def test_ImageDocument():
|
|||
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=64)
|
||||
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def test_PdfDocument():
|
|||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=1024)
|
||||
GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
|
|||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size)
|
||||
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue