# openrag/src/utils/document_processing.py
import hashlib
import os
import sys
import platform
from collections import defaultdict

from .gpu_detection import detect_gpu_devices
from utils.logging_config import get_logger

logger = get_logger(__name__)

# Global converter cache for worker processes
_worker_converter = None


def get_worker_converter():
    """Get or create a DocumentConverter instance for this worker process"""
    global _worker_converter
    if _worker_converter is None:
        from docling.document_converter import DocumentConverter

        # Configure GPU settings for this worker
        has_gpu_devices, _ = detect_gpu_devices()
        if not has_gpu_devices:
            # Force CPU-only mode in subprocess
            os.environ["USE_CPU_ONLY"] = "true"
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
            os.environ["TRANSFORMERS_OFFLINE"] = "0"
            os.environ["TORCH_USE_CUDA_DSA"] = "0"

            # Try to disable CUDA in torch if available
            try:
                import torch

                torch.cuda.is_available = lambda: False
            except ImportError:
                pass
        else:
            # GPU mode - let libraries use GPU if available
            os.environ.pop("USE_CPU_ONLY", None)
            os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = (
                "1"  # Still disable progress bars
            )

        logger.info(
            "Initializing DocumentConverter in worker process", worker_pid=os.getpid()
        )
        _worker_converter = DocumentConverter()
        logger.info("DocumentConverter ready in worker process", worker_pid=os.getpid())

    return _worker_converter
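
# Note: with concurrent.futures.ProcessPoolExecutor, each worker process keeps its own
# copy of this module's globals, so the converter above is built at most once per worker
# and then reused for every document that worker handles. A sketch of paying that cost
# up front (an optional pattern, not something this module requires) is to pass this
# function as the pool initializer so every worker warms its cache before the first job:
#
#     from concurrent.futures import ProcessPoolExecutor
#     pool = ProcessPoolExecutor(max_workers=4, initializer=get_worker_converter)
#
# The initializer's return value is discarded; only the side effect of populating
# _worker_converter matters.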


def extract_relevant(doc_dict: dict) -> dict:
    """
    Given the full export_to_dict() result:
      - Grabs origin metadata (hash, filename, mimetype)
      - Finds every text fragment in `texts`, groups them by page_no
      - Flattens tables in `tables` into tab-separated text, grouping by row
      - Concatenates each page's fragments and each table into its own chunk
    Returns a slimmed dict ready for indexing, with each chunk under "text".
    """
    origin = doc_dict.get("origin", {})
    chunks = []

    # 1) process free-text fragments
    page_texts = defaultdict(list)
    for txt in doc_dict.get("texts", []):
        prov = txt.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None
        if page_no is not None:
            page_texts[page_no].append(txt.get("text", "").strip())

    for page in sorted(page_texts):
        chunks.append(
            {"page": page, "type": "text", "text": "\n".join(page_texts[page])}
        )

    # 2) process tables
    for t_idx, table in enumerate(doc_dict.get("tables", [])):
        prov = table.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None

        # group cells by their row index
        rows = defaultdict(list)
        for cell in table.get("data", {}).get("table_cells", []):
            r = cell.get("start_row_offset_idx")
            c = cell.get("start_col_offset_idx")
            text = cell.get("text", "").strip()
            rows[r].append((c, text))

        # build a tab-separated line for each row, in order
        flat_rows = []
        for r in sorted(rows):
            cells = [txt for _, txt in sorted(rows[r], key=lambda x: x[0])]
            flat_rows.append("\t".join(cells))

        chunks.append(
            {
                "page": page_no,
                "type": "table",
                "table_index": t_idx,
                "text": "\n".join(flat_rows),
            }
        )

    return {
        "id": origin.get("binary_hash"),
        "filename": origin.get("filename"),
        "mimetype": origin.get("mimetype"),
        "chunks": chunks,
    }
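
# For reference, a minimal doc_dict that exercises every key read above (this is only
# the subset of docling's export_to_dict() output that extract_relevant touches, not
# the full document schema), together with the chunks it yields:
#
#     doc_dict = {
#         "origin": {"binary_hash": 123, "filename": "a.pdf", "mimetype": "application/pdf"},
#         "texts": [{"prov": [{"page_no": 1}], "text": "Hello"}],
#         "tables": [{
#             "prov": [{"page_no": 2}],
#             "data": {"table_cells": [
#                 {"start_row_offset_idx": 0, "start_col_offset_idx": 0, "text": "A"},
#                 {"start_row_offset_idx": 0, "start_col_offset_idx": 1, "text": "B"},
#             ]},
#         }],
#     }
#     extract_relevant(doc_dict)["chunks"]
#     # -> [{"page": 1, "type": "text", "text": "Hello"},
#     #     {"page": 2, "type": "table", "table_index": 0, "text": "A\tB"}]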


def process_document_sync(file_path: str):
    """Synchronous document processing function for multiprocessing"""
    import traceback
    import psutil

    process = psutil.Process()
    start_memory = process.memory_info().rss / 1024 / 1024  # MB

    try:
        logger.info(
            "Starting document processing",
            worker_pid=os.getpid(),
            file_path=file_path,
            initial_memory_mb=f"{start_memory:.1f}",
        )

        # Check file size
        try:
            file_size = os.path.getsize(file_path) / 1024 / 1024  # MB
            logger.info(
                "File size determined",
                worker_pid=os.getpid(),
                file_size_mb=f"{file_size:.1f}",
            )
        except OSError as e:
            logger.warning("Cannot get file size", worker_pid=os.getpid(), error=str(e))
            file_size = 0

        # Get the cached converter for this worker
        try:
            logger.info("Getting document converter", worker_pid=os.getpid())
            converter = get_worker_converter()
            memory_after_converter = process.memory_info().rss / 1024 / 1024
            logger.info(
                "Memory after converter init",
                worker_pid=os.getpid(),
                memory_mb=f"{memory_after_converter:.1f}",
            )
        except Exception as e:
            logger.error(
                "Failed to initialize converter", worker_pid=os.getpid(), error=str(e)
            )
            traceback.print_exc()
            raise

        # Compute file hash
        try:
            logger.info("Computing file hash", worker_pid=os.getpid())
            sha256 = hashlib.sha256()
            with open(file_path, "rb") as f:
                while True:
                    chunk = f.read(1 << 20)
                    if not chunk:
                        break
                    sha256.update(chunk)
            file_hash = sha256.hexdigest()
            logger.info(
                "File hash computed",
                worker_pid=os.getpid(),
                file_hash_prefix=file_hash[:12],
            )
        except Exception as e:
            logger.error(
                "Failed to compute file hash", worker_pid=os.getpid(), error=str(e)
            )
            traceback.print_exc()
            raise

        # Convert with docling
        try:
            logger.info("Starting docling conversion", worker_pid=os.getpid())
            memory_before_convert = process.memory_info().rss / 1024 / 1024
            logger.info(
                "Memory before conversion",
                worker_pid=os.getpid(),
                memory_mb=f"{memory_before_convert:.1f}",
            )
            result = converter.convert(file_path)
            memory_after_convert = process.memory_info().rss / 1024 / 1024
            logger.info(
                "Memory after conversion",
                worker_pid=os.getpid(),
                memory_mb=f"{memory_after_convert:.1f}",
            )
            logger.info("Docling conversion completed", worker_pid=os.getpid())
            full_doc = result.document.export_to_dict()
            memory_after_export = process.memory_info().rss / 1024 / 1024
            logger.info(
                "Memory after export",
                worker_pid=os.getpid(),
                memory_mb=f"{memory_after_export:.1f}",
            )
        except Exception as e:
            current_memory = process.memory_info().rss / 1024 / 1024
            logger.error(
                "Failed during docling conversion",
                worker_pid=os.getpid(),
                error=str(e),
                current_memory_mb=f"{current_memory:.1f}",
            )
            traceback.print_exc()
            raise

        # Extract relevant content (same logic as extract_relevant)
        try:
            logger.info("Extracting relevant content", worker_pid=os.getpid())
            origin = full_doc.get("origin", {})
            texts = full_doc.get("texts", [])
            logger.info(
                "Found text fragments",
                worker_pid=os.getpid(),
                fragment_count=len(texts),
            )
            page_texts = defaultdict(list)
            for txt in texts:
                prov = txt.get("prov", [])
                page_no = prov[0].get("page_no") if prov else None
                if page_no is not None:
                    page_texts[page_no].append(txt.get("text", "").strip())

            chunks = []
            for page in sorted(page_texts):
                joined = "\n".join(page_texts[page])
                chunks.append({"page": page, "text": joined})
            logger.info(
                "Created chunks from pages",
                worker_pid=os.getpid(),
                chunk_count=len(chunks),
                page_count=len(page_texts),
            )
        except Exception as e:
            logger.error(
                "Failed during content extraction", worker_pid=os.getpid(), error=str(e)
            )
            traceback.print_exc()
            raise

        final_memory = process.memory_info().rss / 1024 / 1024
        memory_delta = final_memory - start_memory
        logger.info(
            "Document processing completed successfully",
            worker_pid=os.getpid(),
            final_memory_mb=f"{final_memory:.1f}",
            memory_delta_mb=f"{memory_delta:.1f}",
        )
        return {
            "id": file_hash,
            "filename": origin.get("filename"),
            "mimetype": origin.get("mimetype"),
            "chunks": chunks,
            "file_path": file_path,
        }

    except Exception as e:
        final_memory = process.memory_info().rss / 1024 / 1024
        memory_delta = final_memory - start_memory
        logger.error(
            "FATAL ERROR in process_document_sync",
            worker_pid=os.getpid(),
            file_path=file_path,
            python_version=sys.version,
            memory_at_crash_mb=f"{final_memory:.1f}",
            memory_delta_mb=f"{memory_delta:.1f}",
            error_type=type(e).__name__,
            error=str(e),
        )
        logger.error("Full traceback:", worker_pid=os.getpid())
        traceback.print_exc()

        # Try to get more system info before crashing
        try:
            logger.error(
                "System info",
                worker_pid=os.getpid(),
                system=f"{platform.system()} {platform.release()}",
                architecture=platform.machine(),
            )
        except Exception:
            pass

        # Re-raise to trigger BrokenProcessPool in main process
        raise
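

# A minimal sketch of how a parent process might drive process_document_sync through a
# process pool; the sample paths and worker count are illustrative, not part of this
# module. BrokenProcessPool is what the caller sees when a worker dies hard (e.g. OOM),
# which is why the function above re-raises after logging.
if __name__ == "__main__":
    from concurrent.futures import ProcessPoolExecutor
    from concurrent.futures.process import BrokenProcessPool

    sample_paths = ["./docs/example.pdf"]  # hypothetical input files

    with ProcessPoolExecutor(max_workers=2, initializer=get_worker_converter) as pool:
        futures = {pool.submit(process_document_sync, path): path for path in sample_paths}
        for future, path in futures.items():
            try:
                doc = future.result()
                print(f"{path}: {len(doc['chunks'])} chunks, id={doc['id'][:12]}")
            except BrokenProcessPool:
                print(f"{path}: worker process crashed")
            except Exception as exc:
                print(f"{path}: failed with {type(exc).__name__}: {exc}")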