From 268d027c8f2abae7339b4c7d33642c3135c56e7a Mon Sep 17 00:00:00 2001 From: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:42:11 +0100 Subject: [PATCH] feat: Use threading in the standard pipeline and move old behavior to legacy (#2452) * rename standard to legacy Signed-off-by: Michele Dolfi * remove old standard pipeline Signed-off-by: Michele Dolfi * move threaded to standard Signed-off-by: Michele Dolfi * add backwards compatible threaded pipeline Signed-off-by: Michele Dolfi * Updates for threaded pipeline to lower memory requirements Signed-off-by: Christoph Auer * updating deps seem to remove the corrupted double-linked list error Signed-off-by: Michele Dolfi * update pinning Signed-off-by: Michele Dolfi * use main lock Signed-off-by: Michele Dolfi * add more threadsafe blocks Signed-off-by: Michele Dolfi * rename batch_timeout_seconds Signed-off-by: Michele Dolfi --------- Signed-off-by: Michele Dolfi Signed-off-by: Christoph Auer Co-authored-by: Christoph Auer --- docling/datamodel/pipeline_options.py | 23 +- docling/models/layout_model.py | 4 + .../pipeline/legacy_standard_pdf_pipeline.py | 242 +++++++ docling/pipeline/standard_pdf_pipeline.py | 679 +++++++++++++++--- .../threaded_standard_pdf_pipeline.py | 648 +---------------- tests/test_threaded_pipeline.py | 2 +- uv.lock | 5 + 7 files changed, 851 insertions(+), 752 deletions(-) create mode 100644 docling/pipeline/legacy_standard_pdf_pipeline.py diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 11448437..dc00a5cf 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -361,15 +361,7 @@ class PdfPipelineOptions(PaginatedPipelineOptions): generate_parsed_pages: bool = False - -class ProcessingPipeline(str, Enum): - STANDARD = "standard" - VLM = "vlm" - ASR = "asr" - - -class ThreadedPdfPipelineOptions(PdfPipelineOptions): - """Pipeline options for the threaded PDF pipeline with batching and backpressure control""" + ### Arguments for threaded PDF pipeline with batching and backpressure control # Batch sizes for different stages ocr_batch_size: int = 4 @@ -377,7 +369,18 @@ class ThreadedPdfPipelineOptions(PdfPipelineOptions): table_batch_size: int = 4 # Timing control - batch_timeout_seconds: float = 2.0 + batch_polling_interval_seconds: float = 0.5 # Backpressure and queue control queue_max_size: int = 100 + + +class ProcessingPipeline(str, Enum): + LEGACY = "legacy" + STANDARD = "standard" + VLM = "vlm" + ASR = "asr" + + +class ThreadedPdfPipelineOptions(PdfPipelineOptions): + """Pipeline options for the threaded PDF pipeline with batching and backpressure control""" diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 1879da40..ea5cfb7c 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -167,6 +167,10 @@ class LayoutModel(BasePageModel): valid_pages.append(page) valid_page_images.append(page_image) + print(f"{len(pages)=}, {pages[0].page_no}-{pages[-1].page_no}") + print(f"{len(valid_pages)=}") + print(f"{len(valid_page_images)=}") + # Process all valid pages with batch prediction batch_predictions = [] if valid_page_images: diff --git a/docling/pipeline/legacy_standard_pdf_pipeline.py b/docling/pipeline/legacy_standard_pdf_pipeline.py new file mode 100644 index 00000000..edcf04f4 --- /dev/null +++ b/docling/pipeline/legacy_standard_pdf_pipeline.py @@ -0,0 +1,242 @@ +import logging +import warnings +from pathlib import Path +from typing 
import Optional, cast + +import numpy as np +from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem + +from docling.backend.abstract_backend import AbstractDocumentBackend +from docling.backend.pdf_backend import PdfDocumentBackend +from docling.datamodel.base_models import AssembledUnit, Page +from docling.datamodel.document import ConversionResult +from docling.datamodel.layout_model_specs import LayoutModelConfig +from docling.datamodel.pipeline_options import PdfPipelineOptions +from docling.datamodel.settings import settings +from docling.models.base_ocr_model import BaseOcrModel +from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions +from docling.models.factories import get_ocr_factory +from docling.models.layout_model import LayoutModel +from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions +from docling.models.page_preprocessing_model import ( + PagePreprocessingModel, + PagePreprocessingOptions, +) +from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions +from docling.models.table_structure_model import TableStructureModel +from docling.pipeline.base_pipeline import PaginatedPipeline +from docling.utils.model_downloader import download_models +from docling.utils.profiling import ProfilingScope, TimeRecorder + +_log = logging.getLogger(__name__) + + +class LegacyStandardPdfPipeline(PaginatedPipeline): + def __init__(self, pipeline_options: PdfPipelineOptions): + super().__init__(pipeline_options) + self.pipeline_options: PdfPipelineOptions + + with warnings.catch_warnings(): # deprecated generate_table_images + warnings.filterwarnings("ignore", category=DeprecationWarning) + self.keep_images = ( + self.pipeline_options.generate_page_images + or self.pipeline_options.generate_picture_images + or self.pipeline_options.generate_table_images + ) + + self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions()) + + ocr_model = self.get_ocr_model(artifacts_path=self.artifacts_path) + + self.build_pipe = [ + # Pre-processing + PagePreprocessingModel( + options=PagePreprocessingOptions( + images_scale=pipeline_options.images_scale, + ) + ), + # OCR + ocr_model, + # Layout model + LayoutModel( + artifacts_path=self.artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + options=pipeline_options.layout_options, + ), + # Table structure model + TableStructureModel( + enabled=pipeline_options.do_table_structure, + artifacts_path=self.artifacts_path, + options=pipeline_options.table_structure_options, + accelerator_options=pipeline_options.accelerator_options, + ), + # Page assemble + PageAssembleModel(options=PageAssembleOptions()), + ] + + self.enrichment_pipe = [ + # Code Formula Enrichment Model + CodeFormulaModel( + enabled=pipeline_options.do_code_enrichment + or pipeline_options.do_formula_enrichment, + artifacts_path=self.artifacts_path, + options=CodeFormulaModelOptions( + do_code_enrichment=pipeline_options.do_code_enrichment, + do_formula_enrichment=pipeline_options.do_formula_enrichment, + ), + accelerator_options=pipeline_options.accelerator_options, + ), + *self.enrichment_pipe, + ] + + if ( + self.pipeline_options.do_formula_enrichment + or self.pipeline_options.do_code_enrichment + or self.pipeline_options.do_picture_classification + or self.pipeline_options.do_picture_description + ): + self.keep_backend = True + + @staticmethod + def download_models_hf( + local_dir: Optional[Path] = None, force: bool = False + ) -> Path: + 
warnings.warn( + "The usage of LegacyStandardPdfPipeline.download_models_hf() is deprecated " + "use instead the utility `docling-tools models download`, or " + "the upstream method docling.utils.models_downloader.download_all()", + DeprecationWarning, + stacklevel=3, + ) + + output_dir = download_models(output_dir=local_dir, force=force, progress=False) + return output_dir + + def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: + factory = get_ocr_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + return factory.create_instance( + options=self.pipeline_options.ocr_options, + enabled=self.pipeline_options.do_ocr, + artifacts_path=artifacts_path, + accelerator_options=self.pipeline_options.accelerator_options, + ) + + def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page: + with TimeRecorder(conv_res, "page_init"): + page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore + if page._backend is not None and page._backend.is_valid(): + page.size = page._backend.get_size() + + return page + + def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult: + all_elements = [] + all_headers = [] + all_body = [] + + with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT): + for p in conv_res.pages: + if p.assembled is not None: + for el in p.assembled.body: + all_body.append(el) + for el in p.assembled.headers: + all_headers.append(el) + for el in p.assembled.elements: + all_elements.append(el) + + conv_res.assembled = AssembledUnit( + elements=all_elements, headers=all_headers, body=all_body + ) + + conv_res.document = self.reading_order_model(conv_res) + + # Generate page images in the output + if self.pipeline_options.generate_page_images: + for page in conv_res.pages: + assert page.image is not None + page_no = page.page_no + 1 + conv_res.document.pages[page_no].image = ImageRef.from_pil( + page.image, dpi=int(72 * self.pipeline_options.images_scale) + ) + + # Generate images of the requested element types + with warnings.catch_warnings(): # deprecated generate_table_images + warnings.filterwarnings("ignore", category=DeprecationWarning) + if ( + self.pipeline_options.generate_picture_images + or self.pipeline_options.generate_table_images + ): + scale = self.pipeline_options.images_scale + for element, _level in conv_res.document.iterate_items(): + if not isinstance(element, DocItem) or len(element.prov) == 0: + continue + if ( + isinstance(element, PictureItem) + and self.pipeline_options.generate_picture_images + ) or ( + isinstance(element, TableItem) + and self.pipeline_options.generate_table_images + ): + page_ix = element.prov[0].page_no - 1 + page = next( + (p for p in conv_res.pages if p.page_no == page_ix), + cast("Page", None), + ) + assert page is not None + assert page.size is not None + assert page.image is not None + + crop_bbox = ( + element.prov[0] + .bbox.scaled(scale=scale) + .to_top_left_origin( + page_height=page.size.height * scale + ) + ) + + cropped_im = page.image.crop(crop_bbox.as_tuple()) + element.image = ImageRef.from_pil( + cropped_im, dpi=int(72 * scale) + ) + + # Aggregate confidence values for document: + if len(conv_res.pages) > 0: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=RuntimeWarning, + message="Mean of empty slice|All-NaN slice encountered", + ) + conv_res.confidence.layout_score = float( + np.nanmean( + [c.layout_score for c in conv_res.confidence.pages.values()] + ) + ) + 
conv_res.confidence.parse_score = float( + np.nanquantile( + [c.parse_score for c in conv_res.confidence.pages.values()], + q=0.1, # parse score should relate to worst 10% of pages. + ) + ) + conv_res.confidence.table_score = float( + np.nanmean( + [c.table_score for c in conv_res.confidence.pages.values()] + ) + ) + conv_res.confidence.ocr_score = float( + np.nanmean( + [c.ocr_score for c in conv_res.confidence.pages.values()] + ) + ) + + return conv_res + + @classmethod + def get_default_options(cls) -> PdfPipelineOptions: + return PdfPipelineOptions() + + @classmethod + def is_backend_supported(cls, backend: AbstractDocumentBackend): + return isinstance(backend, PdfDocumentBackend) diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 1722ca5b..82bf012f 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -1,19 +1,39 @@ +"""Thread-safe, production-ready PDF pipeline +================================================ +A self-contained, thread-safe PDF conversion pipeline exploiting parallelism between pipeline stages and models. + +* **Per-run isolation** - every :py:meth:`execute` call uses its own bounded queues and worker + threads so that concurrent invocations never share mutable state. +* **Deterministic run identifiers** - pages are tracked with an internal *run-id* instead of + relying on :pyfunc:`id`, which may clash after garbage collection. +* **Explicit back-pressure & shutdown** - producers block on full queues; queue *close()* + propagates downstream so stages terminate deterministically without sentinels. +* **Minimal shared state** - heavyweight models are initialised once per pipeline instance + and only read by worker threads; no runtime mutability is exposed. +* **Strict typing & clean API usage** - code is fully annotated and respects *coding_rules.md*. 
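+
+A minimal usage sketch (illustrative only; it assumes the usual ``DocumentConverter``
+entry point, and the option values shown simply mirror the defaults introduced here):
+
+    from docling.datamodel.base_models import InputFormat
+    from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions
+    from docling.document_converter import DocumentConverter, PdfFormatOption
+
+    opts = ThreadedPdfPipelineOptions(
+        ocr_batch_size=4,                    # pages per OCR batch
+        layout_batch_size=4,                 # pages per layout-model batch
+        table_batch_size=4,                  # pages per table-structure batch
+        batch_polling_interval_seconds=0.5,  # polling interval for a stage's input queue
+        queue_max_size=100,                  # bound on inter-stage queues (back-pressure)
+    )
+    converter = DocumentConverter(
+        format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=opts)}
+    )
+    result = converter.convert("document.pdf")  # placeholder input path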
+""" + +from __future__ import annotations + +import itertools import logging +import threading +import time import warnings +from collections import defaultdict, deque +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, cast +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, cast import numpy as np from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem from docling.backend.abstract_backend import AbstractDocumentBackend from docling.backend.pdf_backend import PdfDocumentBackend -from docling.datamodel.base_models import AssembledUnit, Page +from docling.datamodel.base_models import AssembledUnit, ConversionStatus, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.layout_model_specs import LayoutModelConfig -from docling.datamodel.pipeline_options import PdfPipelineOptions +from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions from docling.datamodel.settings import settings -from docling.models.base_ocr_model import BaseOcrModel from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions from docling.models.factories import get_ocr_factory from docling.models.layout_model import LayoutModel @@ -24,132 +44,588 @@ from docling.models.page_preprocessing_model import ( ) from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions from docling.models.table_structure_model import TableStructureModel -from docling.pipeline.base_pipeline import PaginatedPipeline -from docling.utils.model_downloader import download_models +from docling.pipeline.base_pipeline import ConvertPipeline from docling.utils.profiling import ProfilingScope, TimeRecorder +from docling.utils.utils import chunkify _log = logging.getLogger(__name__) +# ────────────────────────────────────────────────────────────────────────────── +# Helper data structures +# ────────────────────────────────────────────────────────────────────────────── -class StandardPdfPipeline(PaginatedPipeline): - def __init__(self, pipeline_options: PdfPipelineOptions): + +@dataclass +class ThreadedItem: + """Envelope that travels between pipeline stages.""" + + payload: Optional[Page] + run_id: int # Unique per *execute* call, monotonic across pipeline instance + page_no: int + conv_res: ConversionResult + error: Optional[Exception] = None + is_failed: bool = False + + +@dataclass +class ProcessingResult: + """Aggregated outcome of a pipeline run.""" + + pages: List[Page] = field(default_factory=list) + failed_pages: List[Tuple[int, Exception]] = field(default_factory=list) + total_expected: int = 0 + + @property + def success_count(self) -> int: + return len(self.pages) + + @property + def failure_count(self) -> int: + return len(self.failed_pages) + + @property + def is_partial_success(self) -> bool: + return 0 < self.success_count < self.total_expected + + @property + def is_complete_failure(self) -> bool: + return self.success_count == 0 and self.failure_count > 0 + + +class ThreadedQueue: + """Bounded queue with blocking put/ get_batch and explicit *close()* semantics.""" + + __slots__ = ("_closed", "_items", "_lock", "_max", "_not_empty", "_not_full") + + def __init__(self, max_size: int) -> None: + self._max: int = max_size + self._items: deque[ThreadedItem] = deque() + self._lock = threading.Lock() + self._not_full = threading.Condition(self._lock) + self._not_empty = threading.Condition(self._lock) + self._closed = False + + # 
---------------------------------------------------------------- put() + def put(self, item: ThreadedItem, timeout: Optional[float] | None = None) -> bool: + """Block until queue accepts *item* or is closed. Returns *False* if closed.""" + with self._not_full: + if self._closed: + return False + start = time.monotonic() + while len(self._items) >= self._max and not self._closed: + if timeout is not None: + remaining = timeout - (time.monotonic() - start) + if remaining <= 0: + return False + self._not_full.wait(remaining) + else: + self._not_full.wait() + if self._closed: + return False + self._items.append(item) + self._not_empty.notify() + return True + + # ------------------------------------------------------------ get_batch() + def get_batch( + self, size: int, timeout: Optional[float] | None = None + ) -> List[ThreadedItem]: + """Return up to *size* items. Blocks until ≥1 item present or queue closed/timeout.""" + with self._not_empty: + start = time.monotonic() + while not self._items and not self._closed: + if timeout is not None: + remaining = timeout - (time.monotonic() - start) + if remaining <= 0: + return [] + self._not_empty.wait(remaining) + else: + self._not_empty.wait() + batch: List[ThreadedItem] = [] + while self._items and len(batch) < size: + batch.append(self._items.popleft()) + if batch: + self._not_full.notify_all() + return batch + + # ---------------------------------------------------------------- close() + def close(self) -> None: + with self._lock: + self._closed = True + self._not_empty.notify_all() + self._not_full.notify_all() + + # -------------------------------------------------------------- property + @property + def closed(self) -> bool: + return self._closed + + +class ThreadedPipelineStage: + """A single pipeline stage backed by one worker thread.""" + + def __init__( + self, + *, + name: str, + model: Any, + batch_size: int, + batch_timeout: float, + queue_max_size: int, + postprocess: Optional[Callable[[ThreadedItem], None]] = None, + ) -> None: + self.name = name + self.model = model + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self.input_queue = ThreadedQueue(queue_max_size) + self._outputs: list[ThreadedQueue] = [] + self._thread: Optional[threading.Thread] = None + self._running = False + self._postprocess = postprocess + + # ---------------------------------------------------------------- wiring + def add_output_queue(self, q: ThreadedQueue) -> None: + self._outputs.append(q) + + # -------------------------------------------------------------- lifecycle + def start(self) -> None: + if self._running: + return + self._running = True + self._thread = threading.Thread( + target=self._run, name=f"Stage-{self.name}", daemon=True + ) + self._thread.start() + + def stop(self) -> None: + if not self._running: + return + self._running = False + self.input_queue.close() + if self._thread is not None: + self._thread.join(timeout=30.0) + if self._thread.is_alive(): + _log.warning("Stage %s did not terminate cleanly within 30s", self.name) + + # ------------------------------------------------------------------ _run + def _run(self) -> None: + try: + while self._running: + batch = self.input_queue.get_batch(self.batch_size, self.batch_timeout) + if not batch and self.input_queue.closed: + break + processed = self._process_batch(batch) + self._emit(processed) + except Exception: # pragma: no cover - top-level guard + _log.exception("Fatal error in stage %s", self.name) + finally: + for q in self._outputs: + q.close() + + # 
----------------------------------------------------- _process_batch() + def _process_batch(self, batch: Sequence[ThreadedItem]) -> list[ThreadedItem]: + """Run *model* on *batch* grouped by run_id to maximise batching.""" + groups: dict[int, list[ThreadedItem]] = defaultdict(list) + for itm in batch: + groups[itm.run_id].append(itm) + + result: list[ThreadedItem] = [] + for rid, items in groups.items(): + good: list[ThreadedItem] = [i for i in items if not i.is_failed] + if not good: + result.extend(items) + continue + try: + # Filter out None payloads and ensure type safety + pages_with_payloads = [ + (i, i.payload) for i in good if i.payload is not None + ] + if len(pages_with_payloads) != len(good): + # Some items have None payloads, mark all as failed + for it in items: + it.is_failed = True + it.error = RuntimeError("Page payload is None") + result.extend(items) + continue + + pages: List[Page] = [payload for _, payload in pages_with_payloads] + processed_pages = list(self.model(good[0].conv_res, pages)) # type: ignore[arg-type] + if len(processed_pages) != len(pages): # strict mismatch guard + raise RuntimeError( + f"Model {self.name} returned wrong number of pages" + ) + for idx, page in enumerate(processed_pages): + result.append( + ThreadedItem( + payload=page, + run_id=rid, + page_no=good[idx].page_no, + conv_res=good[idx].conv_res, + ) + ) + except Exception as exc: + _log.error("Stage %s failed for run %d: %s", self.name, rid, exc) + for it in items: + it.is_failed = True + it.error = exc + result.extend(items) + return result + + # -------------------------------------------------------------- _emit() + def _emit(self, items: Iterable[ThreadedItem]) -> None: + for item in items: + if self._postprocess is not None: + self._postprocess(item) + for q in self._outputs: + if not q.put(item): + _log.error("Output queue closed while emitting from %s", self.name) + + +class PreprocessThreadedStage(ThreadedPipelineStage): + """Pipeline stage that lazily loads PDF backends just-in-time.""" + + def __init__( + self, + *, + batch_timeout: float, + queue_max_size: int, + model: Any, + ) -> None: + super().__init__( + name="preprocess", + model=model, + batch_size=1, + batch_timeout=batch_timeout, + queue_max_size=queue_max_size, + ) + + def _process_batch(self, batch: Sequence[ThreadedItem]) -> list[ThreadedItem]: + groups: dict[int, list[ThreadedItem]] = defaultdict(list) + for itm in batch: + groups[itm.run_id].append(itm) + + result: list[ThreadedItem] = [] + for rid, items in groups.items(): + good = [i for i in items if not i.is_failed] + if not good: + result.extend(items) + continue + try: + pages_with_payloads: list[tuple[ThreadedItem, Page]] = [] + for it in good: + page = it.payload + if page is None: + raise RuntimeError("Page payload is None") + if page._backend is None: + backend = it.conv_res.input._backend + assert isinstance(backend, PdfDocumentBackend), ( + "Threaded pipeline only supports PdfDocumentBackend." 
+ ) + page_backend = backend.load_page(page.page_no) + page._backend = page_backend + if page_backend.is_valid(): + page.size = page_backend.get_size() + pages_with_payloads.append((it, page)) + + pages = [payload for _, payload in pages_with_payloads] + processed_pages = list( + self.model(good[0].conv_res, pages) # type: ignore[arg-type] + ) + if len(processed_pages) != len(pages): + raise RuntimeError( + "PagePreprocessingModel returned unexpected number of pages" + ) + for idx, processed_page in enumerate(processed_pages): + result.append( + ThreadedItem( + payload=processed_page, + run_id=rid, + page_no=good[idx].page_no, + conv_res=good[idx].conv_res, + ) + ) + except Exception as exc: + _log.error("Stage preprocess failed for run %d: %s", rid, exc) + for it in items: + it.is_failed = True + it.error = exc + result.extend(items) + return result + + +@dataclass +class RunContext: + """Wiring for a single *execute* call.""" + + stages: list[ThreadedPipelineStage] + first_stage: ThreadedPipelineStage + output_queue: ThreadedQueue + + +# ────────────────────────────────────────────────────────────────────────────── +# Main pipeline +# ────────────────────────────────────────────────────────────────────────────── + + +class StandardPdfPipeline(ConvertPipeline): + """High-performance PDF pipeline with multi-threaded stages.""" + + def __init__(self, pipeline_options: ThreadedPdfPipelineOptions) -> None: super().__init__(pipeline_options) - self.pipeline_options: PdfPipelineOptions + self.pipeline_options: ThreadedPdfPipelineOptions = pipeline_options + self._run_seq = itertools.count(1) # deterministic, monotonic run ids - with warnings.catch_warnings(): # deprecated generate_table_images - warnings.filterwarnings("ignore", category=DeprecationWarning) - self.keep_images = ( - self.pipeline_options.generate_page_images - or self.pipeline_options.generate_picture_images - or self.pipeline_options.generate_table_images + # initialise heavy models once + self._init_models() + + # ──────────────────────────────────────────────────────────────────────── + # Heavy-model initialisation & helpers + # ──────────────────────────────────────────────────────────────────────── + + def _init_models(self) -> None: + art_path = self.artifacts_path + self.keep_images = ( + self.pipeline_options.generate_page_images + or self.pipeline_options.generate_picture_images + or self.pipeline_options.generate_table_images + ) + self.preprocessing_model = PagePreprocessingModel( + options=PagePreprocessingOptions( + images_scale=self.pipeline_options.images_scale ) - + ) + self.ocr_model = self._make_ocr_model(art_path) + self.layout_model = LayoutModel( + artifacts_path=art_path, + accelerator_options=self.pipeline_options.accelerator_options, + options=self.pipeline_options.layout_options, + ) + self.table_model = TableStructureModel( + enabled=self.pipeline_options.do_table_structure, + artifacts_path=art_path, + options=self.pipeline_options.table_structure_options, + accelerator_options=self.pipeline_options.accelerator_options, + ) + self.assemble_model = PageAssembleModel(options=PageAssembleOptions()) self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions()) - ocr_model = self.get_ocr_model(artifacts_path=self.artifacts_path) - - self.build_pipe = [ - # Pre-processing - PagePreprocessingModel( - options=PagePreprocessingOptions( - images_scale=pipeline_options.images_scale, - ) - ), - # OCR - ocr_model, - # Layout model - LayoutModel( - artifacts_path=self.artifacts_path, - 
accelerator_options=pipeline_options.accelerator_options, - options=pipeline_options.layout_options, - ), - # Table structure model - TableStructureModel( - enabled=pipeline_options.do_table_structure, - artifacts_path=self.artifacts_path, - options=pipeline_options.table_structure_options, - accelerator_options=pipeline_options.accelerator_options, - ), - # Page assemble - PageAssembleModel(options=PageAssembleOptions()), - ] - + # --- optional enrichment ------------------------------------------------ self.enrichment_pipe = [ # Code Formula Enrichment Model CodeFormulaModel( - enabled=pipeline_options.do_code_enrichment - or pipeline_options.do_formula_enrichment, + enabled=self.pipeline_options.do_code_enrichment + or self.pipeline_options.do_formula_enrichment, artifacts_path=self.artifacts_path, options=CodeFormulaModelOptions( - do_code_enrichment=pipeline_options.do_code_enrichment, - do_formula_enrichment=pipeline_options.do_formula_enrichment, + do_code_enrichment=self.pipeline_options.do_code_enrichment, + do_formula_enrichment=self.pipeline_options.do_formula_enrichment, ), - accelerator_options=pipeline_options.accelerator_options, + accelerator_options=self.pipeline_options.accelerator_options, ), *self.enrichment_pipe, ] - if ( - self.pipeline_options.do_formula_enrichment - or self.pipeline_options.do_code_enrichment - or self.pipeline_options.do_picture_classification - or self.pipeline_options.do_picture_description - ): - self.keep_backend = True - - @staticmethod - def download_models_hf( - local_dir: Optional[Path] = None, force: bool = False - ) -> Path: - warnings.warn( - "The usage of StandardPdfPipeline.download_models_hf() is deprecated " - "use instead the utility `docling-tools models download`, or " - "the upstream method docling.utils.models_downloader.download_all()", - DeprecationWarning, - stacklevel=3, + self.keep_backend = any( + ( + self.pipeline_options.do_formula_enrichment, + self.pipeline_options.do_code_enrichment, + self.pipeline_options.do_picture_classification, + self.pipeline_options.do_picture_description, + ) ) - output_dir = download_models(output_dir=local_dir, force=force, progress=False) - return output_dir - - def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: + # ---------------------------------------------------------------- helpers + def _make_ocr_model(self, art_path: Optional[Path]) -> Any: factory = get_ocr_factory( allow_external_plugins=self.pipeline_options.allow_external_plugins ) return factory.create_instance( options=self.pipeline_options.ocr_options, enabled=self.pipeline_options.do_ocr, - artifacts_path=artifacts_path, + artifacts_path=art_path, accelerator_options=self.pipeline_options.accelerator_options, ) - def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page: - with TimeRecorder(conv_res, "page_init"): - page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore - if page._backend is not None and page._backend.is_valid(): - page.size = page._backend.get_size() + def _release_page_resources(self, item: ThreadedItem) -> None: + page = item.payload + if page is None: + return + if not self.keep_images: + page._image_cache = {} + if not self.keep_backend and page._backend is not None: + page._backend.unload() + page._backend = None + if not self.pipeline_options.generate_parsed_pages: + page.parsed_page = None - return page + # ──────────────────────────────────────────────────────────────────────── + # Build - thread pipeline + # 
──────────────────────────────────────────────────────────────────────── + def _create_run_ctx(self) -> RunContext: + opts = self.pipeline_options + preprocess = PreprocessThreadedStage( + batch_timeout=opts.batch_polling_interval_seconds, + queue_max_size=opts.queue_max_size, + model=self.preprocessing_model, + ) + ocr = ThreadedPipelineStage( + name="ocr", + model=self.ocr_model, + batch_size=opts.ocr_batch_size, + batch_timeout=opts.batch_polling_interval_seconds, + queue_max_size=opts.queue_max_size, + ) + layout = ThreadedPipelineStage( + name="layout", + model=self.layout_model, + batch_size=opts.layout_batch_size, + batch_timeout=opts.batch_polling_interval_seconds, + queue_max_size=opts.queue_max_size, + ) + table = ThreadedPipelineStage( + name="table", + model=self.table_model, + batch_size=opts.table_batch_size, + batch_timeout=opts.batch_polling_interval_seconds, + queue_max_size=opts.queue_max_size, + ) + assemble = ThreadedPipelineStage( + name="assemble", + model=self.assemble_model, + batch_size=1, + batch_timeout=opts.batch_polling_interval_seconds, + queue_max_size=opts.queue_max_size, + postprocess=self._release_page_resources, + ) + + # wire stages + output_q = ThreadedQueue(opts.queue_max_size) + preprocess.add_output_queue(ocr.input_queue) + ocr.add_output_queue(layout.input_queue) + layout.add_output_queue(table.input_queue) + table.add_output_queue(assemble.input_queue) + assemble.add_output_queue(output_q) + + stages = [preprocess, ocr, layout, table, assemble] + return RunContext(stages=stages, first_stage=preprocess, output_queue=output_q) + + # --------------------------------------------------------------------- build + def _build_document(self, conv_res: ConversionResult) -> ConversionResult: + """Stream-build the document while interleaving producer and consumer work.""" + run_id = next(self._run_seq) + assert isinstance(conv_res.input._backend, PdfDocumentBackend) + + # Collect page placeholders; backends are loaded lazily in preprocess stage + start_page, end_page = conv_res.input.limits.page_range + pages: list[Page] = [] + for i in range(conv_res.input.page_count): + if start_page - 1 <= i <= end_page - 1: + page = Page(page_no=i) + conv_res.pages.append(page) + pages.append(page) + + if not pages: + conv_res.status = ConversionStatus.FAILURE + return conv_res + + total_pages: int = len(pages) + ctx: RunContext = self._create_run_ctx() + for st in ctx.stages: + st.start() + + proc = ProcessingResult(total_expected=total_pages) + fed_idx: int = 0 # number of pages successfully queued + batch_size: int = 32 # drain chunk + try: + while proc.success_count + proc.failure_count < total_pages: + # 1) feed - try to enqueue until the first queue is full + while fed_idx < total_pages: + ok = ctx.first_stage.input_queue.put( + ThreadedItem( + payload=pages[fed_idx], + run_id=run_id, + page_no=pages[fed_idx].page_no, + conv_res=conv_res, + ), + timeout=0.0, # non-blocking try-put + ) + if ok: + fed_idx += 1 + if fed_idx == total_pages: + ctx.first_stage.input_queue.close() + else: # queue full - switch to draining + break + + # 2) drain - pull whatever is ready from the output side + out_batch = ctx.output_queue.get_batch(batch_size, timeout=0.05) + for itm in out_batch: + if itm.run_id != run_id: + continue + if itm.is_failed or itm.error: + proc.failed_pages.append( + (itm.page_no, itm.error or RuntimeError("unknown error")) + ) + else: + assert itm.payload is not None + proc.pages.append(itm.payload) + + # 3) failure safety - downstream closed early -> mark 
missing pages failed + if not out_batch and ctx.output_queue.closed: + missing = total_pages - (proc.success_count + proc.failure_count) + if missing > 0: + proc.failed_pages.extend( + [(-1, RuntimeError("pipeline terminated early"))] * missing + ) + break + finally: + for st in ctx.stages: + st.stop() + ctx.output_queue.close() + + self._integrate_results(conv_res, proc) + return conv_res + + # ---------------------------------------------------- integrate_results() + def _integrate_results( + self, conv_res: ConversionResult, proc: ProcessingResult + ) -> None: + page_map = {p.page_no: p for p in proc.pages} + conv_res.pages = [ + page_map.get(p.page_no, p) + for p in conv_res.pages + if p.page_no in page_map + or not any(fp == p.page_no for fp, _ in proc.failed_pages) + ] + if proc.is_complete_failure: + conv_res.status = ConversionStatus.FAILURE + elif proc.is_partial_success: + conv_res.status = ConversionStatus.PARTIAL_SUCCESS + else: + conv_res.status = ConversionStatus.SUCCESS + if not self.keep_images: + for p in conv_res.pages: + p._image_cache = {} + for p in conv_res.pages: + if not self.keep_backend and p._backend is not None: + p._backend.unload() + if not self.pipeline_options.generate_parsed_pages: + del p.parsed_page + p.parsed_page = None + + # ---------------------------------------------------------------- assemble def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult: - all_elements = [] - all_headers = [] - all_body = [] - + elements, headers, body = [], [], [] with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT): for p in conv_res.pages: - if p.assembled is not None: - for el in p.assembled.body: - all_body.append(el) - for el in p.assembled.headers: - all_headers.append(el) - for el in p.assembled.elements: - all_elements.append(el) - + if p.assembled: + elements.extend(p.assembled.elements) + headers.extend(p.assembled.headers) + body.extend(p.assembled.body) conv_res.assembled = AssembledUnit( - elements=all_elements, headers=all_headers, body=all_body + elements=elements, headers=headers, body=body ) - conv_res.document = self.reading_order_model(conv_res) # Generate page images in the output @@ -233,10 +709,21 @@ class StandardPdfPipeline(PaginatedPipeline): return conv_res + # ---------------------------------------------------------------- misc @classmethod - def get_default_options(cls) -> PdfPipelineOptions: - return PdfPipelineOptions() + def get_default_options(cls) -> ThreadedPdfPipelineOptions: + return ThreadedPdfPipelineOptions() @classmethod - def is_backend_supported(cls, backend: AbstractDocumentBackend): + def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool: return isinstance(backend, PdfDocumentBackend) + + def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus: + return conv_res.status + + def _unload(self, conv_res: ConversionResult) -> None: + for p in conv_res.pages: + if p._backend is not None: + p._backend.unload() + if conv_res.input._backend: + conv_res.input._backend.unload() diff --git a/docling/pipeline/threaded_standard_pdf_pipeline.py b/docling/pipeline/threaded_standard_pdf_pipeline.py index a31270d0..21af9813 100644 --- a/docling/pipeline/threaded_standard_pdf_pipeline.py +++ b/docling/pipeline/threaded_standard_pdf_pipeline.py @@ -1,647 +1,5 @@ -# threaded_standard_pdf_pipeline.py -"""Thread-safe, production-ready PDF pipeline -================================================ -A self-contained, thread-safe PDF conversion pipeline exploiting 
parallelism between pipeline stages and models. +from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline -* **Per-run isolation** - every :py:meth:`execute` call uses its own bounded queues and worker - threads so that concurrent invocations never share mutable state. -* **Deterministic run identifiers** - pages are tracked with an internal *run-id* instead of - relying on :pyfunc:`id`, which may clash after garbage collection. -* **Explicit back-pressure & shutdown** - producers block on full queues; queue *close()* - propagates downstream so stages terminate deterministically without sentinels. -* **Minimal shared state** - heavyweight models are initialised once per pipeline instance - and only read by worker threads; no runtime mutability is exposed. -* **Strict typing & clean API usage** - code is fully annotated and respects *coding_rules.md*. -""" -from __future__ import annotations - -import itertools -import logging -import threading -import time -import warnings -from collections import defaultdict, deque -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Iterable, List, Optional, Sequence, Tuple, cast - -import numpy as np -from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem - -from docling.backend.abstract_backend import AbstractDocumentBackend -from docling.backend.pdf_backend import PdfDocumentBackend -from docling.datamodel.base_models import AssembledUnit, ConversionStatus, Page -from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions -from docling.datamodel.settings import settings -from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions -from docling.models.factories import get_ocr_factory -from docling.models.layout_model import LayoutModel -from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions -from docling.models.page_preprocessing_model import ( - PagePreprocessingModel, - PagePreprocessingOptions, -) -from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions -from docling.models.table_structure_model import TableStructureModel -from docling.pipeline.base_pipeline import ConvertPipeline -from docling.utils.profiling import ProfilingScope, TimeRecorder -from docling.utils.utils import chunkify - -_log = logging.getLogger(__name__) - -# ────────────────────────────────────────────────────────────────────────────── -# Helper data structures -# ────────────────────────────────────────────────────────────────────────────── - - -@dataclass -class ThreadedItem: - """Envelope that travels between pipeline stages.""" - - payload: Optional[Page] - run_id: int # Unique per *execute* call, monotonic across pipeline instance - page_no: int - conv_res: ConversionResult - error: Optional[Exception] = None - is_failed: bool = False - - -@dataclass -class ProcessingResult: - """Aggregated outcome of a pipeline run.""" - - pages: List[Page] = field(default_factory=list) - failed_pages: List[Tuple[int, Exception]] = field(default_factory=list) - total_expected: int = 0 - - @property - def success_count(self) -> int: - return len(self.pages) - - @property - def failure_count(self) -> int: - return len(self.failed_pages) - - @property - def is_partial_success(self) -> bool: - return 0 < self.success_count < self.total_expected - - @property - def is_complete_failure(self) -> bool: - return self.success_count == 0 and self.failure_count > 0 - - 
-class ThreadedQueue: - """Bounded queue with blocking put/ get_batch and explicit *close()* semantics.""" - - __slots__ = ("_closed", "_items", "_lock", "_max", "_not_empty", "_not_full") - - def __init__(self, max_size: int) -> None: - self._max: int = max_size - self._items: deque[ThreadedItem] = deque() - self._lock = threading.Lock() - self._not_full = threading.Condition(self._lock) - self._not_empty = threading.Condition(self._lock) - self._closed = False - - # ---------------------------------------------------------------- put() - def put(self, item: ThreadedItem, timeout: Optional[float] | None = None) -> bool: - """Block until queue accepts *item* or is closed. Returns *False* if closed.""" - with self._not_full: - if self._closed: - return False - start = time.monotonic() - while len(self._items) >= self._max and not self._closed: - if timeout is not None: - remaining = timeout - (time.monotonic() - start) - if remaining <= 0: - return False - self._not_full.wait(remaining) - else: - self._not_full.wait() - if self._closed: - return False - self._items.append(item) - self._not_empty.notify() - return True - - # ------------------------------------------------------------ get_batch() - def get_batch( - self, size: int, timeout: Optional[float] | None = None - ) -> List[ThreadedItem]: - """Return up to *size* items. Blocks until ≥1 item present or queue closed/timeout.""" - with self._not_empty: - start = time.monotonic() - while not self._items and not self._closed: - if timeout is not None: - remaining = timeout - (time.monotonic() - start) - if remaining <= 0: - return [] - self._not_empty.wait(remaining) - else: - self._not_empty.wait() - batch: List[ThreadedItem] = [] - while self._items and len(batch) < size: - batch.append(self._items.popleft()) - if batch: - self._not_full.notify_all() - return batch - - # ---------------------------------------------------------------- close() - def close(self) -> None: - with self._lock: - self._closed = True - self._not_empty.notify_all() - self._not_full.notify_all() - - # -------------------------------------------------------------- property - @property - def closed(self) -> bool: - return self._closed - - -class ThreadedPipelineStage: - """A single pipeline stage backed by one worker thread.""" - - def __init__( - self, - *, - name: str, - model: Any, - batch_size: int, - batch_timeout: float, - queue_max_size: int, - ) -> None: - self.name = name - self.model = model - self.batch_size = batch_size - self.batch_timeout = batch_timeout - self.input_queue = ThreadedQueue(queue_max_size) - self._outputs: list[ThreadedQueue] = [] - self._thread: Optional[threading.Thread] = None - self._running = False - - # ---------------------------------------------------------------- wiring - def add_output_queue(self, q: ThreadedQueue) -> None: - self._outputs.append(q) - - # -------------------------------------------------------------- lifecycle - def start(self) -> None: - if self._running: - return - self._running = True - self._thread = threading.Thread( - target=self._run, name=f"Stage-{self.name}", daemon=True - ) - self._thread.start() - - def stop(self) -> None: - if not self._running: - return - self._running = False - self.input_queue.close() - if self._thread is not None: - self._thread.join(timeout=30.0) - if self._thread.is_alive(): - _log.warning("Stage %s did not terminate cleanly within 30s", self.name) - - # ------------------------------------------------------------------ _run - def _run(self) -> None: - try: - while 
self._running: - batch = self.input_queue.get_batch(self.batch_size, self.batch_timeout) - if not batch and self.input_queue.closed: - break - processed = self._process_batch(batch) - self._emit(processed) - except Exception: # pragma: no cover - top-level guard - _log.exception("Fatal error in stage %s", self.name) - finally: - for q in self._outputs: - q.close() - - # ----------------------------------------------------- _process_batch() - def _process_batch(self, batch: Sequence[ThreadedItem]) -> list[ThreadedItem]: - """Run *model* on *batch* grouped by run_id to maximise batching.""" - groups: dict[int, list[ThreadedItem]] = defaultdict(list) - for itm in batch: - groups[itm.run_id].append(itm) - - result: list[ThreadedItem] = [] - for rid, items in groups.items(): - good: list[ThreadedItem] = [i for i in items if not i.is_failed] - if not good: - result.extend(items) - continue - try: - # Filter out None payloads and ensure type safety - pages_with_payloads = [ - (i, i.payload) for i in good if i.payload is not None - ] - if len(pages_with_payloads) != len(good): - # Some items have None payloads, mark all as failed - for it in items: - it.is_failed = True - it.error = RuntimeError("Page payload is None") - result.extend(items) - continue - - pages: List[Page] = [payload for _, payload in pages_with_payloads] - processed_pages = list(self.model(good[0].conv_res, pages)) # type: ignore[arg-type] - if len(processed_pages) != len(pages): # strict mismatch guard - raise RuntimeError( - f"Model {self.name} returned wrong number of pages" - ) - for idx, page in enumerate(processed_pages): - result.append( - ThreadedItem( - payload=page, - run_id=rid, - page_no=good[idx].page_no, - conv_res=good[idx].conv_res, - ) - ) - except Exception as exc: - _log.error("Stage %s failed for run %d: %s", self.name, rid, exc) - for it in items: - it.is_failed = True - it.error = exc - result.extend(items) - return result - - # -------------------------------------------------------------- _emit() - def _emit(self, items: Iterable[ThreadedItem]) -> None: - for item in items: - for q in self._outputs: - if not q.put(item): - _log.error("Output queue closed while emitting from %s", self.name) - - -@dataclass -class RunContext: - """Wiring for a single *execute* call.""" - - stages: list[ThreadedPipelineStage] - first_stage: ThreadedPipelineStage - output_queue: ThreadedQueue - - -# ────────────────────────────────────────────────────────────────────────────── -# Main pipeline -# ────────────────────────────────────────────────────────────────────────────── - - -class ThreadedStandardPdfPipeline(ConvertPipeline): - """High-performance PDF pipeline with multi-threaded stages.""" - - def __init__(self, pipeline_options: ThreadedPdfPipelineOptions) -> None: - super().__init__(pipeline_options) - self.pipeline_options: ThreadedPdfPipelineOptions = pipeline_options - self._run_seq = itertools.count(1) # deterministic, monotonic run ids - - # initialise heavy models once - self._init_models() - - # ──────────────────────────────────────────────────────────────────────── - # Heavy-model initialisation & helpers - # ──────────────────────────────────────────────────────────────────────── - - def _init_models(self) -> None: - art_path = self.artifacts_path - self.keep_images = ( - self.pipeline_options.generate_page_images - or self.pipeline_options.generate_picture_images - or self.pipeline_options.generate_table_images - ) - self.preprocessing_model = PagePreprocessingModel( - options=PagePreprocessingOptions( - 
images_scale=self.pipeline_options.images_scale - ) - ) - self.ocr_model = self._make_ocr_model(art_path) - self.layout_model = LayoutModel( - artifacts_path=art_path, - accelerator_options=self.pipeline_options.accelerator_options, - options=self.pipeline_options.layout_options, - ) - self.table_model = TableStructureModel( - enabled=self.pipeline_options.do_table_structure, - artifacts_path=art_path, - options=self.pipeline_options.table_structure_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - self.assemble_model = PageAssembleModel(options=PageAssembleOptions()) - self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions()) - - # --- optional enrichment ------------------------------------------------ - self.enrichment_pipe = [ - # Code Formula Enrichment Model - CodeFormulaModel( - enabled=self.pipeline_options.do_code_enrichment - or self.pipeline_options.do_formula_enrichment, - artifacts_path=self.artifacts_path, - options=CodeFormulaModelOptions( - do_code_enrichment=self.pipeline_options.do_code_enrichment, - do_formula_enrichment=self.pipeline_options.do_formula_enrichment, - ), - accelerator_options=self.pipeline_options.accelerator_options, - ), - *self.enrichment_pipe, - ] - - self.keep_backend = any( - ( - self.pipeline_options.do_formula_enrichment, - self.pipeline_options.do_code_enrichment, - self.pipeline_options.do_picture_classification, - self.pipeline_options.do_picture_description, - ) - ) - - # ---------------------------------------------------------------- helpers - def _make_ocr_model(self, art_path: Optional[Path]) -> Any: - factory = get_ocr_factory( - allow_external_plugins=self.pipeline_options.allow_external_plugins - ) - return factory.create_instance( - options=self.pipeline_options.ocr_options, - enabled=self.pipeline_options.do_ocr, - artifacts_path=art_path, - accelerator_options=self.pipeline_options.accelerator_options, - ) - - # ──────────────────────────────────────────────────────────────────────── - # Build - thread pipeline - # ──────────────────────────────────────────────────────────────────────── - - def _create_run_ctx(self) -> RunContext: - opts = self.pipeline_options - preprocess = ThreadedPipelineStage( - name="preprocess", - model=self.preprocessing_model, - batch_size=1, - batch_timeout=opts.batch_timeout_seconds, - queue_max_size=opts.queue_max_size, - ) - ocr = ThreadedPipelineStage( - name="ocr", - model=self.ocr_model, - batch_size=opts.ocr_batch_size, - batch_timeout=opts.batch_timeout_seconds, - queue_max_size=opts.queue_max_size, - ) - layout = ThreadedPipelineStage( - name="layout", - model=self.layout_model, - batch_size=opts.layout_batch_size, - batch_timeout=opts.batch_timeout_seconds, - queue_max_size=opts.queue_max_size, - ) - table = ThreadedPipelineStage( - name="table", - model=self.table_model, - batch_size=opts.table_batch_size, - batch_timeout=opts.batch_timeout_seconds, - queue_max_size=opts.queue_max_size, - ) - assemble = ThreadedPipelineStage( - name="assemble", - model=self.assemble_model, - batch_size=1, - batch_timeout=opts.batch_timeout_seconds, - queue_max_size=opts.queue_max_size, - ) - - # wire stages - output_q = ThreadedQueue(opts.queue_max_size) - preprocess.add_output_queue(ocr.input_queue) - ocr.add_output_queue(layout.input_queue) - layout.add_output_queue(table.input_queue) - table.add_output_queue(assemble.input_queue) - assemble.add_output_queue(output_q) - - stages = [preprocess, ocr, layout, table, assemble] - return RunContext(stages=stages, 
first_stage=preprocess, output_queue=output_q) - - # --------------------------------------------------------------------- build - def _build_document(self, conv_res: ConversionResult) -> ConversionResult: - """Stream-build the document while interleaving producer and consumer work.""" - run_id = next(self._run_seq) - assert isinstance(conv_res.input._backend, PdfDocumentBackend) - backend = conv_res.input._backend - - # preload & initialise pages ------------------------------------------------------------- - start_page, end_page = conv_res.input.limits.page_range - pages: list[Page] = [] - for i in range(conv_res.input.page_count): - if start_page - 1 <= i <= end_page - 1: - page = Page(page_no=i) - page._backend = backend.load_page(i) - if page._backend and page._backend.is_valid(): - page.size = page._backend.get_size() - conv_res.pages.append(page) - pages.append(page) - - if not pages: - conv_res.status = ConversionStatus.FAILURE - return conv_res - - total_pages: int = len(pages) - ctx: RunContext = self._create_run_ctx() - for st in ctx.stages: - st.start() - - proc = ProcessingResult(total_expected=total_pages) - fed_idx: int = 0 # number of pages successfully queued - batch_size: int = 32 # drain chunk - try: - while proc.success_count + proc.failure_count < total_pages: - # 1) feed - try to enqueue until the first queue is full - while fed_idx < total_pages: - ok = ctx.first_stage.input_queue.put( - ThreadedItem( - payload=pages[fed_idx], - run_id=run_id, - page_no=pages[fed_idx].page_no, - conv_res=conv_res, - ), - timeout=0.0, # non-blocking try-put - ) - if ok: - fed_idx += 1 - if fed_idx == total_pages: - ctx.first_stage.input_queue.close() - else: # queue full - switch to draining - break - - # 2) drain - pull whatever is ready from the output side - out_batch = ctx.output_queue.get_batch(batch_size, timeout=0.05) - for itm in out_batch: - if itm.run_id != run_id: - continue - if itm.is_failed or itm.error: - proc.failed_pages.append( - (itm.page_no, itm.error or RuntimeError("unknown error")) - ) - else: - assert itm.payload is not None - proc.pages.append(itm.payload) - - # 3) failure safety - downstream closed early -> mark missing pages failed - if not out_batch and ctx.output_queue.closed: - missing = total_pages - (proc.success_count + proc.failure_count) - if missing > 0: - proc.failed_pages.extend( - [(-1, RuntimeError("pipeline terminated early"))] * missing - ) - break - finally: - for st in ctx.stages: - st.stop() - ctx.output_queue.close() - - self._integrate_results(conv_res, proc) - return conv_res - - # ---------------------------------------------------- integrate_results() - def _integrate_results( - self, conv_res: ConversionResult, proc: ProcessingResult - ) -> None: - page_map = {p.page_no: p for p in proc.pages} - conv_res.pages = [ - page_map.get(p.page_no, p) - for p in conv_res.pages - if p.page_no in page_map - or not any(fp == p.page_no for fp, _ in proc.failed_pages) - ] - if proc.is_complete_failure: - conv_res.status = ConversionStatus.FAILURE - elif proc.is_partial_success: - conv_res.status = ConversionStatus.PARTIAL_SUCCESS - else: - conv_res.status = ConversionStatus.SUCCESS - if not self.keep_images: - for p in conv_res.pages: - p._image_cache = {} - for p in conv_res.pages: - if not self.keep_backend and p._backend is not None: - p._backend.unload() - if not self.pipeline_options.generate_parsed_pages: - del p.parsed_page - p.parsed_page = None - - # ---------------------------------------------------------------- assemble - def 
_assemble_document(self, conv_res: ConversionResult) -> ConversionResult: - elements, headers, body = [], [], [] - with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT): - for p in conv_res.pages: - if p.assembled: - elements.extend(p.assembled.elements) - headers.extend(p.assembled.headers) - body.extend(p.assembled.body) - conv_res.assembled = AssembledUnit( - elements=elements, headers=headers, body=body - ) - conv_res.document = self.reading_order_model(conv_res) - - # Generate page images in the output - if self.pipeline_options.generate_page_images: - for page in conv_res.pages: - assert page.image is not None - page_no = page.page_no + 1 - conv_res.document.pages[page_no].image = ImageRef.from_pil( - page.image, dpi=int(72 * self.pipeline_options.images_scale) - ) - - # Generate images of the requested element types - with warnings.catch_warnings(): # deprecated generate_table_images - warnings.filterwarnings("ignore", category=DeprecationWarning) - if ( - self.pipeline_options.generate_picture_images - or self.pipeline_options.generate_table_images - ): - scale = self.pipeline_options.images_scale - for element, _level in conv_res.document.iterate_items(): - if not isinstance(element, DocItem) or len(element.prov) == 0: - continue - if ( - isinstance(element, PictureItem) - and self.pipeline_options.generate_picture_images - ) or ( - isinstance(element, TableItem) - and self.pipeline_options.generate_table_images - ): - page_ix = element.prov[0].page_no - 1 - page = next( - (p for p in conv_res.pages if p.page_no == page_ix), - cast("Page", None), - ) - assert page is not None - assert page.size is not None - assert page.image is not None - - crop_bbox = ( - element.prov[0] - .bbox.scaled(scale=scale) - .to_top_left_origin( - page_height=page.size.height * scale - ) - ) - - cropped_im = page.image.crop(crop_bbox.as_tuple()) - element.image = ImageRef.from_pil( - cropped_im, dpi=int(72 * scale) - ) - - # Aggregate confidence values for document: - if len(conv_res.pages) > 0: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - category=RuntimeWarning, - message="Mean of empty slice|All-NaN slice encountered", - ) - conv_res.confidence.layout_score = float( - np.nanmean( - [c.layout_score for c in conv_res.confidence.pages.values()] - ) - ) - conv_res.confidence.parse_score = float( - np.nanquantile( - [c.parse_score for c in conv_res.confidence.pages.values()], - q=0.1, # parse score should relate to worst 10% of pages. 
-                    )
-                )
-                conv_res.confidence.table_score = float(
-                    np.nanmean(
-                        [c.table_score for c in conv_res.confidence.pages.values()]
-                    )
-                )
-                conv_res.confidence.ocr_score = float(
-                    np.nanmean(
-                        [c.ocr_score for c in conv_res.confidence.pages.values()]
-                    )
-                )
-
-        return conv_res
-
-    # ---------------------------------------------------------------- misc
-    @classmethod
-    def get_default_options(cls) -> ThreadedPdfPipelineOptions:
-        return ThreadedPdfPipelineOptions()
-
-    @classmethod
-    def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
-        return isinstance(backend, PdfDocumentBackend)
-
-    def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
-        return conv_res.status
-
-    def _unload(self, conv_res: ConversionResult) -> None:
-        for p in conv_res.pages:
-            if p._backend is not None:
-                p._backend.unload()
-        if conv_res.input._backend:
-            conv_res.input._backend.unload()
+class ThreadedStandardPdfPipeline(StandardPdfPipeline):
+    """Backwards compatible import for ThreadedStandardPdfPipeline."""
diff --git a/tests/test_threaded_pipeline.py b/tests/test_threaded_pipeline.py
index 5810565c..c24716cd 100644
--- a/tests/test_threaded_pipeline.py
+++ b/tests/test_threaded_pipeline.py
@@ -42,7 +42,7 @@ def test_threaded_pipeline_multiple_documents():
                     layout_batch_size=1,
                     table_batch_size=1,
                     ocr_batch_size=1,
-                    batch_timeout_seconds=1.0,
+                    batch_polling_interval_seconds=1.0,
                     do_table_structure=do_ts,
                     do_ocr=do_ocr,
                 ),
diff --git a/uv.lock b/uv.lock
index 6797aed6..20951526 100644
--- a/uv.lock
+++ b/uv.lock
@@ -5943,6 +5943,9 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/20/8a/b35a615ae6f04550d696bb179c414538b3b477999435fdd4ad75b76139e4/pybase64-1.4.2-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:a370dea7b1cee2a36a4d5445d4e09cc243816c5bc8def61f602db5a6f5438e52", size = 54320, upload-time = "2025-07-27T13:03:27.495Z" },
     { url = "https://files.pythonhosted.org/packages/d3/a9/8bd4f9bcc53689f1b457ecefed1eaa080e4949d65a62c31a38b7253d5226/pybase64-1.4.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9aa4de83f02e462a6f4e066811c71d6af31b52d7484de635582d0e3ec3d6cc3e", size = 56482, upload-time = "2025-07-27T13:03:28.942Z" },
     { url = "https://files.pythonhosted.org/packages/75/e5/4a7735b54a1191f61c3f5c2952212c85c2d6b06eb5fb3671c7603395f70c/pybase64-1.4.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83a1c2f9ed00fee8f064d548c8654a480741131f280e5750bb32475b7ec8ee38", size = 70959, upload-time = "2025-07-27T13:03:30.171Z" },
+    { url = "https://files.pythonhosted.org/packages/f4/56/5337f27a8b8d2d6693f46f7b36bae47895e5820bfa259b0072574a4e1057/pybase64-1.4.2-cp313-cp313-android_21_arm64_v8a.whl", hash = "sha256:0f331aa59549de21f690b6ccc79360ffed1155c3cfbc852eb5c097c0b8565a2b", size = 33888, upload-time = "2025-07-27T13:03:35.698Z" },
+    { url = "https://files.pythonhosted.org/packages/e3/ff/470768f0fe6de0aa302a8cb1bdf2f9f5cffc3f69e60466153be68bc953aa/pybase64-1.4.2-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:69d3f0445b0faeef7bb7f93bf8c18d850785e2a77f12835f49e524cc54af04e7", size = 30914, upload-time = "2025-07-27T13:03:38.475Z" },
+    { url = "https://files.pythonhosted.org/packages/75/6b/d328736662665e0892409dc410353ebef175b1be5eb6bab1dad579efa6df/pybase64-1.4.2-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:2372b257b1f4dd512f317fb27e77d313afd137334de64c87de8374027aacd88a", size = 31380, upload-time = "2025-07-27T13:03:39.7Z" },
     { url = "https://files.pythonhosted.org/packages/ca/96/7ff718f87c67f4147c181b73d0928897cefa17dc75d7abc6e37730d5908f/pybase64-1.4.2-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:fb794502b4b1ec91c4ca5d283ae71aef65e3de7721057bd9e2b3ec79f7a62d7d", size = 38230, upload-time = "2025-07-27T13:03:41.637Z" },
     { url = "https://files.pythonhosted.org/packages/71/ab/db4dbdfccb9ca874d6ce34a0784761471885d96730de85cee3d300381529/pybase64-1.4.2-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d377d48acf53abf4b926c2a7a24a19deb092f366a04ffd856bf4b3aa330b025d", size = 71608, upload-time = "2025-07-27T13:03:47.01Z" },
     { url = "https://files.pythonhosted.org/packages/f2/58/7f2cef1ceccc682088958448d56727369de83fa6b29148478f4d2acd107a/pybase64-1.4.2-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:ab9cdb6a8176a5cb967f53e6ad60e40c83caaa1ae31c5e1b29e5c8f507f17538", size = 56413, upload-time = "2025-07-27T13:03:49.908Z" },
@@ -5964,6 +5967,8 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/95/f0/c392c4ac8ccb7a34b28377c21faa2395313e3c676d76c382642e19a20703/pybase64-1.4.2-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:ad59362fc267bf15498a318c9e076686e4beeb0dfe09b457fabbc2b32468b97a", size = 58103, upload-time = "2025-07-27T13:04:29.996Z" },
     { url = "https://files.pythonhosted.org/packages/32/30/00ab21316e7df8f526aa3e3dc06f74de6711d51c65b020575d0105a025b2/pybase64-1.4.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:01593bd064e7dcd6c86d04e94e44acfe364049500c20ac68ca1e708fbb2ca970", size = 60779, upload-time = "2025-07-27T13:04:31.549Z" },
     { url = "https://files.pythonhosted.org/packages/a6/65/114ca81839b1805ce4a2b7d58bc16e95634734a2059991f6382fc71caf3e/pybase64-1.4.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5b81547ad8ea271c79fdf10da89a1e9313cb15edcba2a17adf8871735e9c02a0", size = 74684, upload-time = "2025-07-27T13:04:32.976Z" },
+    { url = "https://files.pythonhosted.org/packages/99/bf/00a87d951473ce96c8c08af22b6983e681bfabdb78dd2dcf7ee58eac0932/pybase64-1.4.2-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:4157ad277a32cf4f02a975dffc62a3c67d73dfa4609b2c1978ef47e722b18b8e", size = 30924, upload-time = "2025-07-27T13:04:39.189Z" },
+    { url = "https://files.pythonhosted.org/packages/ae/43/dee58c9d60e60e6fb32dc6da722d84592e22f13c277297eb4ce6baf99a99/pybase64-1.4.2-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:e113267dc349cf624eb4f4fbf53fd77835e1aa048ac6877399af426aab435757", size = 31390, upload-time = "2025-07-27T13:04:40.995Z" },
     { url = "https://files.pythonhosted.org/packages/e1/11/b28906fc2e330b8b1ab4bc845a7bef808b8506734e90ed79c6062b095112/pybase64-1.4.2-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:cea5aaf218fd9c5c23afacfe86fd4464dfedc1a0316dd3b5b4075b068cc67df0", size = 38212, upload-time = "2025-07-27T13:04:42.729Z" },
     { url = "https://files.pythonhosted.org/packages/e4/2e/851eb51284b97354ee5dfa1309624ab90920696e91a33cd85b13d20cc5c1/pybase64-1.4.2-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a3e54dcf0d0305ec88473c9d0009f698cabf86f88a8a10090efeff2879c421bb", size = 71674, upload-time = "2025-07-27T13:04:49.294Z" },
     { url = "https://files.pythonhosted.org/packages/a4/8e/3479266bc0e65f6cc48b3938d4a83bff045330649869d950a378f2ddece0/pybase64-1.4.2-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:753da25d4fd20be7bda2746f545935773beea12d5cb5ec56ec2d2960796477b1", size = 56461, upload-time = "2025-07-27T13:04:52.37Z" },