feat: outsource chunking parameters to the extract-chunks-from-documents task (#289)
* feat: outsource chunking parameters to the extract-chunks-from-documents task
This commit is contained in:
parent
bfa0f06fb4
commit
9e7ab6492a
11 changed files with 40 additions and 20 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
|
||||||
class AudioDocument(Document):
|
class AudioDocument(Document):
|
||||||
type: str = "audio"
|
type: str = "audio"
|
||||||
|
|
@ -9,11 +9,12 @@ class AudioDocument(Document):
|
||||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||||
return(result.text)
|
return(result.text)
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
    """Transcribe the audio file and yield chunks of the transcript.

    Parameters:
        chunk_size: maximum size of each produced chunk.
        chunker: registered chunker name resolved via ChunkerConfig
            (e.g. "text_chunker").

    Yields:
        Chunks produced by the resolved chunker's read().
    """
    # Transcribe the audio file
    text = self.create_transcript()

    chunker_class = ChunkerConfig.get_chunker(chunker)
    # Use a distinct local name: the original rebound `chunker`, shadowing
    # the `chunker: str` parameter mid-function.
    document_chunker = chunker_class(self, chunk_size = chunk_size, get_text = lambda: [text])

    yield from document_chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
|
|
||||||
|
class ChunkerConfig:
    """Registry mapping chunker names to their implementing classes."""

    chunker_mapping = {
        "text_chunker": TextChunker
    }

    @classmethod
    def get_chunker(cls, chunker_name: str):
        """Return the chunker class registered under *chunker_name*.

        Raises:
            NotImplementedError: when no chunker is registered under that name.
        """
        if chunker_name not in cls.chunker_mapping:
            raise NotImplementedError(
                f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
            )
        return cls.chunker_mapping[chunker_name]
|
||||||
|
|
@ -13,5 +13,5 @@ class Document(DataPoint):
|
||||||
"type": "Document"
|
"type": "Document"
|
||||||
}
|
}
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str) -> str:
    """Read this document, yielding chunks; implemented by subclasses.

    Bug fix: the signature had `chunker = str`, which silently made the
    builtin `str` type the default *value* instead of annotating the
    parameter; `chunker: str` is the intended declaration.

    Parameters:
        chunk_size: maximum size of each produced chunk.
        chunker: registered chunker name (e.g. "text_chunker").
    """
    pass
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
|
||||||
class ImageDocument(Document):
|
class ImageDocument(Document):
|
||||||
type: str = "image"
|
type: str = "image"
|
||||||
|
|
@ -10,10 +10,11 @@ class ImageDocument(Document):
|
||||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||||
return(result.choices[0].message.content)
|
return(result.choices[0].message.content)
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
    """Transcribe the image file and yield chunks of the transcription.

    Parameters:
        chunk_size: maximum size of each produced chunk.
        chunker: registered chunker name resolved via ChunkerConfig
            (e.g. "text_chunker").

    Yields:
        Chunks produced by the resolved chunker's read().
    """
    # Transcribe the image file
    text = self.transcribe_image()

    chunker_class = ChunkerConfig.get_chunker(chunker)
    # Use a distinct local name: the original rebound `chunker`, shadowing
    # the `chunker: str` parameter mid-function.
    document_chunker = chunker_class(self, chunk_size = chunk_size, get_text = lambda: [text])

    yield from document_chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
|
||||||
class PdfDocument(Document):
|
class PdfDocument(Document):
|
||||||
type: str = "pdf"
|
type: str = "pdf"
|
||||||
|
|
||||||
def read(self, chunk_size: int):
|
def read(self, chunk_size: int, chunker: str):
|
||||||
file = PdfReader(self.raw_data_location)
|
file = PdfReader(self.raw_data_location)
|
||||||
|
|
||||||
def get_text():
|
def get_text():
|
||||||
|
|
@ -13,7 +13,8 @@ class PdfDocument(Document):
|
||||||
page_text = page.extract_text()
|
page_text = page.extract_text()
|
||||||
yield page_text
|
yield page_text
|
||||||
|
|
||||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
|
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
|
||||||
class TextDocument(Document):
|
class TextDocument(Document):
|
||||||
type: str = "text"
|
type: str = "text"
|
||||||
|
|
||||||
def read(self, chunk_size: int):
|
def read(self, chunk_size: int, chunker: str):
|
||||||
def get_text():
|
def get_text():
|
||||||
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
|
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -15,6 +15,8 @@ class TextDocument(Document):
|
||||||
|
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
|
|
||||||
|
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from cognee.modules.data.processing.document_types.Document import Document
|
from cognee.modules.data.processing.document_types.Document import Document
|
||||||
|
|
||||||
|
|
||||||
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker: str = "text_chunker"):
    """Asynchronously yield every chunk read from the given documents.

    Parameters:
        documents: documents to chunk, processed in order.
        chunk_size: maximum chunk size forwarded to each document's read().
        chunker: registered chunker name forwarded to read(); annotated
            `: str` here (the original `chunker = 'text_chunker'` omitted
            the annotation used everywhere else in this change).

    Yields:
        Each chunk produced by document.read(), document by document.
    """
    for document in documents:
        for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
            yield document_chunk
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ def test_AudioDocument():
|
||||||
)
|
)
|
||||||
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=64)
|
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
ground_truth["word_count"] == paragraph_data.word_count
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ def test_ImageDocument():
|
||||||
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
||||||
|
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=64)
|
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
ground_truth["word_count"] == paragraph_data.word_count
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ def test_PdfDocument():
|
||||||
)
|
)
|
||||||
|
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=1024)
|
GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
ground_truth["word_count"] == paragraph_data.word_count
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
|
||||||
)
|
)
|
||||||
|
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size)
|
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
ground_truth["word_count"] == paragraph_data.word_count
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue