diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py index 9e3420ab8..d881514a2 100644 --- a/cognee/tasks/documents/classify_documents.py +++ b/cognee/tasks/documents/classify_documents.py @@ -1,12 +1,16 @@ from cognee.modules.data.models import Data from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument +EXTENSION_TO_DOCUMENT_CLASS = { + "pdf": PdfDocument, + "audio": AudioDocument, + "image": ImageDocument, + "txt": TextDocument +} + def classify_documents(data_documents: list[Data]) -> list[Document]: documents = [ - PdfDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else - AudioDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else - ImageDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else - TextDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) + EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location, name=data_item.name) for data_item in data_documents ] return documents