Implement Advanced PDF Loader with unstructured library support
- Added AdvancedPdfLoader class for enhanced PDF processing using the unstructured library. - Integrated fallback mechanism to PyPdfLoader in case of unstructured library import failure or exceptions. - Updated supported loaders to include AdvancedPdfLoader. - Added unit tests for AdvancedPdfLoader to ensure functionality and error handling. - Updated poetry.lock and pyproject.toml to include new dependencies and versions. Signed-off-by: EricXiao <taoiaox@gmail.com>
This commit is contained in:
parent
1a4061a009
commit
6107cb47ca
8 changed files with 2148 additions and 141 deletions
|
|
@ -27,6 +27,7 @@ class LoaderEngine:
|
|||
|
||||
self.default_loader_priority = [
|
||||
"text_loader",
|
||||
"advanced_pdf_loader",
|
||||
"pypdf_loader",
|
||||
"image_loader",
|
||||
"audio_loader",
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@ This module contains loaders that depend on external libraries:
|
|||
These loaders are optional and only available if their dependencies are installed.
|
||||
"""
|
||||
|
||||
from .advanced_pdf_loader import AdvancedPdfLoader
|
||||
from .pypdf_loader import PyPdfLoader
|
||||
|
||||
__all__ = ["PyPdfLoader"]
|
||||
__all__ = ["AdvancedPdfLoader", "PyPdfLoader"]
|
||||
|
||||
# Conditional imports based on dependency availability
|
||||
try:
|
||||
|
|
|
|||
246
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
246
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PageBuffer:
|
||||
page_num: Optional[int]
|
||||
segments: List[str]
|
||||
|
||||
|
||||
class AdvancedPdfLoader(LoaderInterface):
|
||||
"""
|
||||
PDF loader using unstructured library.
|
||||
|
||||
Extracts text content, images, tables from PDF files page by page, providing
|
||||
structured page information and handling PDF-specific errors.
|
||||
"""
|
||||
|
||||
@property
def supported_extensions(self) -> List[str]:
    """File extensions (without the leading dot) this loader accepts."""
    return ["pdf"]
|
||||
|
||||
@property
def supported_mime_types(self) -> List[str]:
    """MIME types this loader accepts."""
    return ["application/pdf"]
|
||||
|
||||
@property
def loader_name(self) -> str:
    """Name under which this loader is registered."""
    return "advanced_pdf_loader"
|
||||
|
||||
def can_handle(self, extension: str, mime_type: str) -> bool:
    """Return True only when both the extension and the MIME type are supported."""
    extension_ok = extension in self.supported_extensions
    return extension_ok and mime_type in self.supported_mime_types
|
||||
|
||||
async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
    """Extract a PDF with the unstructured library and store the result as text.

    Falls back to PyPdfLoader when unstructured cannot be imported, when
    partitioning yields no content, or when any exception is raised while
    processing the file.

    Args:
        file_path: Path to the PDF file.
        strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only").
        **kwargs: Extra keyword arguments forwarded to ``partition_pdf``. They
            override the defaults set below and are also passed to the
            fallback loader.

    Returns:
        The value returned by the file storage's ``store`` call for the
        extracted text, or whatever the fallback loader returns.
    """
    try:
        from unstructured.partition.pdf import partition_pdf
    except ImportError:
        logger.warning(
            "unstructured[pdf] not installed, can't use AdvancedPdfLoader, using PyPDF fallback."
        )
        return await self._fallback(file_path, **kwargs)

    try:
        logger.info("Processing PDF: %s", file_path)

        with open(file_path, "rb") as f:
            file_metadata = await get_file_metadata(f)

        # Name ingested file of current loader based on original file content hash
        storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

        # Set partitioning parameters; caller-supplied kwargs come last so they win.
        partition_kwargs: Dict[str, Any] = {
            "filename": file_path,
            "strategy": strategy,
            "infer_table_structure": True,
            "include_page_breaks": False,
            "include_metadata": True,
            **kwargs,
        }

        # partition_pdf is synchronous; run it in a worker thread so the event
        # loop stays responsive while the document is being parsed.
        elements = await asyncio.to_thread(partition_pdf, **partition_kwargs)

        # Process elements into text content
        page_contents = self._format_elements_by_page(elements)

        # Check if there is any content
        if not page_contents:
            logger.warning(
                "AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
            )
            return await self._fallback(file_path, **kwargs)

        # Combine all page outputs
        full_content = "\n".join(page_contents)

        # Store the content
        storage_config = get_storage_config()
        data_root_directory = storage_config["data_root_directory"]
        storage = get_file_storage(data_root_directory)

        return await storage.store(storage_file_name, full_content)

    except Exception as exc:
        # Deliberate best effort: any failure in the advanced path degrades to PyPDF.
        logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
        return await self._fallback(file_path, **kwargs)
|
||||
|
||||
async def _fallback(self, file_path: str, **kwargs: Any) -> str:
    """Delegate loading of ``file_path`` to a freshly created PyPdfLoader."""
    logger.info("Falling back to PyPDF loader for %s", file_path)
    return await PyPdfLoader().load(file_path, **kwargs)
|
||||
|
||||
def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
    """Group formatted elements into one text block per run of same-page elements.

    Args:
        elements: Elements as returned by the partitioner.

    Returns:
        One string per page run, each starting with a ``Page N:`` header line
        (``Page:`` when no page number is known), followed by the page's
        segments separated by blank lines. Runs with no formatted text are
        dropped.
    """
    page_buffers: List[_PageBuffer] = []
    current_buffer = _PageBuffer(page_num=None, segments=[])

    for element in elements:
        element_dict = self._safe_to_dict(element)
        metadata = element_dict.get("metadata")
        # Metadata may be missing, None or a non-dict object; only a dict can
        # be asked for the page number.
        page_num = metadata.get("page_number") if isinstance(metadata, dict) else None

        # A change of page number closes the current buffer and opens a new one.
        if current_buffer.page_num != page_num:
            if current_buffer.segments:
                page_buffers.append(current_buffer)
            current_buffer = _PageBuffer(page_num=page_num, segments=[])

        formatted = self._format_element(element_dict)
        if formatted:
            current_buffer.segments.append(formatted)

    if current_buffer.segments:
        page_buffers.append(current_buffer)

    page_contents: List[str] = []
    for buffer in page_buffers:
        # Both header variants end with a newline so the header never fuses
        # with the first segment.
        header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:\n"
        page_contents.append(header + "\n\n".join(buffer.segments) + "\n")

    return page_contents
|
||||
|
||||
def _format_element(
|
||||
self,
|
||||
element: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Format element."""
|
||||
element_type = element.get("type")
|
||||
text = self._clean_text(element.get("text", ""))
|
||||
metadata = element.get("metadata", {})
|
||||
|
||||
if element_type.lower() == "table":
|
||||
return self._format_table_element(element) or text
|
||||
|
||||
if element_type.lower() == "image":
|
||||
description = text or self._format_image_element(metadata)
|
||||
return description
|
||||
|
||||
# Ignore header and footer
|
||||
if element_type.lower() in ["header", "footer"]:
|
||||
pass
|
||||
|
||||
return text
|
||||
|
||||
def _format_table_element(self, element: Dict[str, Any]) -> str:
|
||||
"""Format table element."""
|
||||
metadata = element.get("metadata", {})
|
||||
text = self._clean_text(element.get("text", ""))
|
||||
table_html = metadata.get("text_as_html")
|
||||
|
||||
if table_html:
|
||||
return table_html.strip()
|
||||
|
||||
return text
|
||||
|
||||
def _format_image_element(self, metadata: Dict[str, Any]) -> str:
|
||||
"""Format image."""
|
||||
placeholder = "[Image omitted]"
|
||||
image_text = placeholder
|
||||
coordinates = metadata.get("coordinates", {})
|
||||
points = coordinates.get("points") if isinstance(coordinates, dict) else None
|
||||
if points and isinstance(points, tuple) and len(points) == 4:
|
||||
leftup = points[0]
|
||||
rightdown = points[3]
|
||||
if (
|
||||
isinstance(leftup, tuple)
|
||||
and isinstance(rightdown, tuple)
|
||||
and len(leftup) == 2
|
||||
and len(rightdown) == 2
|
||||
):
|
||||
image_text = f"{placeholder} (bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]}))"
|
||||
|
||||
layout_width = coordinates.get("layout_width")
|
||||
layout_height = coordinates.get("layout_height")
|
||||
system = coordinates.get("system")
|
||||
if layout_width and layout_height and system:
|
||||
image_text = (
|
||||
image_text
|
||||
+ f", system={system}, layout_width={layout_width}, layout_height={layout_height}))"
|
||||
)
|
||||
|
||||
return image_text
|
||||
|
||||
def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
|
||||
"""Safe to dict."""
|
||||
try:
|
||||
if hasattr(element, "to_dict"):
|
||||
return element.to_dict()
|
||||
except Exception:
|
||||
pass
|
||||
fallback_type = getattr(element, "category", None)
|
||||
if not fallback_type:
|
||||
fallback_type = getattr(element, "__class__", type("", (), {})).__name__
|
||||
|
||||
return {
|
||||
"type": fallback_type,
|
||||
"text": getattr(element, "text", ""),
|
||||
"metadata": getattr(element, "metadata", {}),
|
||||
}
|
||||
|
||||
def _clean_text(self, value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).replace("\xa0", " ").strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test. The PDF to process is taken from the command line
    # rather than from a path that exists on a single developer's machine.
    import sys

    if len(sys.argv) != 2:
        raise SystemExit(f"usage: {sys.argv[0]} <path-to-pdf>")

    loader = AdvancedPdfLoader()
    asyncio.run(loader.load(sys.argv[1]))
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
from cognee.infrastructure.loaders.external import PyPdfLoader
|
||||
from cognee.infrastructure.loaders.external import AdvancedPdfLoader, PyPdfLoader
|
||||
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
|
||||
|
||||
# Registry for loader implementations
|
||||
supported_loaders = {
|
||||
AdvancedPdfLoader.loader_name: AdvancedPdfLoader,
|
||||
PyPdfLoader.loader_name: PyPdfLoader,
|
||||
TextLoader.loader_name: TextLoader,
|
||||
ImageLoader.loader_name: ImageLoader,
|
||||
|
|
|
|||
164
cognee/tests/test_advanced_pdf_loader.py
Normal file
164
cognee/tests/test_advanced_pdf_loader.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import sys
|
||||
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
|
||||
import pytest
|
||||
|
||||
from cognee.infrastructure.loaders.external.advanced_pdf_loader import AdvancedPdfLoader
|
||||
|
||||
advanced_pdf_loader_module = sys.modules.get(
|
||||
"cognee.infrastructure.loaders.external.advanced_pdf_loader"
|
||||
)
|
||||
|
||||
|
||||
class MockElement:
    """Minimal stand-in for a partitioner element, exposing only what the loader reads."""

    def __init__(self, category, text, metadata):
        self.category, self.text, self.metadata = category, text, metadata

    def to_dict(self):
        """Return the element as a dictionary with type, text and metadata keys."""
        return dict(type=self.category, text=self.text, metadata=self.metadata)
|
||||
|
||||
|
||||
@pytest.fixture
def loader():
    """Provide a fresh AdvancedPdfLoader instance to each test."""
    return AdvancedPdfLoader()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "extension, mime_type, expected",
    [
        ("pdf", "application/pdf", True),
        ("txt", "text/plain", False),
        ("pdf", "text/plain", False),
        ("doc", "application/pdf", False),
    ],
)
def test_can_handle(loader, extension, mime_type, expected):
    """can_handle accepts a file only when extension and MIME type both denote a PDF."""
    handled = loader.can_handle(extension, mime_type)
    assert handled == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
async def test_load_fallback_on_import_error(mock_pypdf_loader, mock_file_open, loader):
    """Test fallback to PyPdfLoader when unstructured cannot be imported."""
    # Prepare Mock
    mock_fallback_instance = MagicMock()
    mock_fallback_instance.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_pypdf_loader.return_value = mock_fallback_instance
    test_file_path = "/fake/path/to/document.pdf"

    # Run
    # A None entry in sys.modules makes the import statement itself raise
    # ImportError, which is the situation this test is about. Patching
    # partition_pdf with a side effect would leave the import working.
    with patch.dict("sys.modules", {"unstructured.partition.pdf": None}):
        result_path = await loader.load(test_file_path)

    # Assert
    assert result_path == "/fake/path/fallback.txt"
    # The loader must return before it ever opens the file, which shows the
    # import-failure branch was taken rather than the generic exception branch.
    mock_file_open.assert_not_called()
    mock_fallback_instance.load.assert_awaited_once_with(test_file_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
    new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_storage_config")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_storage")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch("unstructured.partition.pdf.partition_pdf")
async def test_load_success_with_unstructured(
    mock_partition_pdf,
    mock_pypdf_loader,
    mock_get_file_storage,
    mock_get_storage_config,
    mock_get_file_metadata,
    mock_file_open,
    loader,
):
    """Happy path: unstructured partitions the PDF and the formatted text is stored."""
    # Arrange: two elements on page 1 and a table on page 2.
    title = MockElement(
        category="Title", text="Attention Is All You Need", metadata={"page_number": 1}
    )
    narrative = MockElement(
        category="NarrativeText",
        text="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.",
        metadata={"page_number": 1},
    )
    table = MockElement(
        category="Table",
        text="This is a table.",
        metadata={"page_number": 2, "text_as_html": "<table><tr><td>Data</td></tr></table>"},
    )
    mock_partition_pdf.return_value = [title, narrative, table]

    mock_pypdf_loader.return_value.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_get_file_metadata.return_value = {"content_hash": "abc123def456"}
    mock_get_storage_config.return_value = {"data_root_directory": "/fake/data/root"}

    mock_storage_instance = MagicMock()
    mock_storage_instance.store = AsyncMock(return_value="/stored/text_abc123def456.txt")
    mock_get_file_storage.return_value = mock_storage_instance

    test_file_path = "/fake/path/document.pdf"

    # Act
    result_path = await loader.load(test_file_path)

    # Assert: the path reported by the storage is returned.
    assert result_path == "/stored/text_abc123def456.txt"

    # partition_pdf receives the file name and the default strategy.
    mock_partition_pdf.assert_called_once()
    _, call_kwargs = mock_partition_pdf.call_args
    assert call_kwargs.get("filename") == test_file_path
    assert call_kwargs.get("strategy") == "auto"  # Default strategy

    # The stored text has one header per page and the table rendered as HTML.
    expected_content = (
        "Page 1:\nAttention Is All You Need\n\n"
        "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.\n\n"
        "Page 2:\n<table><tr><td>Data</td></tr></table>\n"
    )
    mock_storage_instance.store.assert_awaited_once_with("text_abc123def456.txt", expected_content)

    # The fallback loader is never instantiated.
    mock_pypdf_loader.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
    new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch(
    "unstructured.partition.pdf.partition_pdf",
    side_effect=Exception("Unstructured failed!"),
)
async def test_load_fallback_on_unstructured_exception(
    mock_partition_pdf, mock_pypdf_loader, mock_get_file_metadata, mock_file_open, loader
):
    """A failure inside partition_pdf makes the loader delegate to PyPdfLoader."""
    # Arrange
    fallback = MagicMock()
    fallback.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_pypdf_loader.return_value = fallback
    mock_get_file_metadata.return_value = {"content_hash": "anyhash"}
    test_file_path = "/fake/path/document.pdf"

    # Act
    result_path = await loader.load(test_file_path)

    # Assert: partition_pdf was attempted once, then the fallback produced the result.
    assert result_path == "/fake/path/fallback.txt"
    mock_partition_pdf.assert_called_once()
    fallback.load.assert_awaited_once_with(test_file_path)
|
||||
1031
poetry.lock
generated
1031
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -42,6 +42,7 @@ dependencies = [
|
|||
"aiofiles>=23.2.1,<24.0.0",
|
||||
"rdflib>=7.1.4,<7.2.0",
|
||||
"pypdf>=4.1.0,<7.0.0",
|
||||
"unstructured[pdf]>=0.18.1,<19",
|
||||
"jinja2>=3.1.3,<4",
|
||||
"matplotlib>=3.8.3,<4",
|
||||
"networkx>=3.4.2,<4",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue