feat: Use threading in the standard pipeline and move old behavior to legacy (#2452)

* rename standard to legacy

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove old standard pipeline

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* move threaded to standard

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add backwards compatible threaded pipeline

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* Updates for threaded pipeline to lower memory requirements

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* updating deps seems to remove the corrupted double-linked list error

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* update pinning

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use main lock

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add more threadsafe blocks

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename batch_timeout_seconds

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
Committed by Michele Dolfi on 2025-10-31 14:42:11 +01:00 (via GitHub)
parent 01577e92d1
commit 268d027c8f
7 changed files with 851 additions and 752 deletions

View File

@@ -361,15 +361,7 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
generate_parsed_pages: bool = False
class ProcessingPipeline(str, Enum):
STANDARD = "standard"
VLM = "vlm"
ASR = "asr"
class ThreadedPdfPipelineOptions(PdfPipelineOptions):
"""Pipeline options for the threaded PDF pipeline with batching and backpressure control"""
### Arguments for threaded PDF pipeline with batching and backpressure control
# Batch sizes for different stages
ocr_batch_size: int = 4
@@ -377,7 +369,18 @@ class ThreadedPdfPipelineOptions(PdfPipelineOptions):
table_batch_size: int = 4
# Timing control
batch_timeout_seconds: float = 2.0
batch_polling_interval_seconds: float = 0.5
# Backpressure and queue control
queue_max_size: int = 100
class ProcessingPipeline(str, Enum):
LEGACY = "legacy"
STANDARD = "standard"
VLM = "vlm"
ASR = "asr"
class ThreadedPdfPipelineOptions(PdfPipelineOptions):
"""Pipeline options for the threaded PDF pipeline with batching and backpressure control"""

View File

@@ -167,6 +167,10 @@ class LayoutModel(BasePageModel):
valid_pages.append(page)
valid_page_images.append(page_image)
print(f"{len(pages)=}, {pages[0].page_no}-{pages[-1].page_no}")
print(f"{len(valid_pages)=}")
print(f"{len(valid_page_images)=}")
# Process all valid pages with batch prediction
batch_predictions = []
if valid_page_images:
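
The hunk above leaves plain print() diagnostics in LayoutModel. If they are meant to stay, a hedged sketch of the same output routed through the module logger, assuming the module defines _log = logging.getLogger(__name__) as the pipeline modules in this commit do:

_log.debug(
    "layout batch: %d pages (%d-%d), %d valid, %d valid images",
    len(pages),
    pages[0].page_no,
    pages[-1].page_no,
    len(valid_pages),
    len(valid_page_images),
)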

View File

@@ -0,0 +1,242 @@
import logging
import warnings
from pathlib import Path
from typing import Optional, cast
import numpy as np
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import AssembledUnit, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.layout_model_specs import LayoutModelConfig
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.factories import get_ocr_factory
from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.model_downloader import download_models
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
class LegacyStandardPdfPipeline(PaginatedPipeline):
def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions
with warnings.catch_warnings(): # deprecated generate_table_images
warnings.filterwarnings("ignore", category=DeprecationWarning)
self.keep_images = (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
)
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
ocr_model = self.get_ocr_model(artifacts_path=self.artifacts_path)
self.build_pipe = [
# Pre-processing
PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=pipeline_options.images_scale,
)
),
# OCR
ocr_model,
# Layout model
LayoutModel(
artifacts_path=self.artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
options=pipeline_options.layout_options,
),
# Table structure model
TableStructureModel(
enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path,
options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options,
),
# Page assemble
PageAssembleModel(options=PageAssembleOptions()),
]
self.enrichment_pipe = [
# Code Formula Enrichment Model
CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment,
artifacts_path=self.artifacts_path,
options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment,
),
accelerator_options=pipeline_options.accelerator_options,
),
*self.enrichment_pipe,
]
if (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_picture_classification
or self.pipeline_options.do_picture_description
):
self.keep_backend = True
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
warnings.warn(
"The usage of LegacyStandardPdfPipeline.download_models_hf() is deprecated "
"use instead the utility `docling-tools models download`, or "
"the upstream method docling.utils.models_downloader.download_all()",
DeprecationWarning,
stacklevel=3,
)
output_dir = download_models(output_dir=local_dir, force=force, progress=False)
return output_dir
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr,
artifacts_path=artifacts_path,
accelerator_options=self.pipeline_options.accelerator_options,
)
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
return page
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
all_elements = []
all_headers = []
all_body = []
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
for p in conv_res.pages:
if p.assembled is not None:
for el in p.assembled.body:
all_body.append(el)
for el in p.assembled.headers:
all_headers.append(el)
for el in p.assembled.elements:
all_elements.append(el)
conv_res.assembled = AssembledUnit(
elements=all_elements, headers=all_headers, body=all_body
)
conv_res.document = self.reading_order_model(conv_res)
# Generate page images in the output
if self.pipeline_options.generate_page_images:
for page in conv_res.pages:
assert page.image is not None
page_no = page.page_no + 1
conv_res.document.pages[page_no].image = ImageRef.from_pil(
page.image, dpi=int(72 * self.pipeline_options.images_scale)
)
# Generate images of the requested element types
with warnings.catch_warnings(): # deprecated generate_table_images
warnings.filterwarnings("ignore", category=DeprecationWarning)
if (
self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
):
scale = self.pipeline_options.images_scale
for element, _level in conv_res.document.iterate_items():
if not isinstance(element, DocItem) or len(element.prov) == 0:
continue
if (
isinstance(element, PictureItem)
and self.pipeline_options.generate_picture_images
) or (
isinstance(element, TableItem)
and self.pipeline_options.generate_table_images
):
page_ix = element.prov[0].page_no - 1
page = next(
(p for p in conv_res.pages if p.page_no == page_ix),
cast("Page", None),
)
assert page is not None
assert page.size is not None
assert page.image is not None
crop_bbox = (
element.prov[0]
.bbox.scaled(scale=scale)
.to_top_left_origin(
page_height=page.size.height * scale
)
)
cropped_im = page.image.crop(crop_bbox.as_tuple())
element.image = ImageRef.from_pil(
cropped_im, dpi=int(72 * scale)
)
# Aggregate confidence values for document:
if len(conv_res.pages) > 0:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="Mean of empty slice|All-NaN slice encountered",
)
conv_res.confidence.layout_score = float(
np.nanmean(
[c.layout_score for c in conv_res.confidence.pages.values()]
)
)
conv_res.confidence.parse_score = float(
np.nanquantile(
[c.parse_score for c in conv_res.confidence.pages.values()],
q=0.1, # parse score should relate to worst 10% of pages.
)
)
conv_res.confidence.table_score = float(
np.nanmean(
[c.table_score for c in conv_res.confidence.pages.values()]
)
)
conv_res.confidence.ocr_score = float(
np.nanmean(
[c.ocr_score for c in conv_res.confidence.pages.values()]
)
)
return conv_res
@classmethod
def get_default_options(cls) -> PdfPipelineOptions:
return PdfPipelineOptions()
@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend):
return isinstance(backend, PdfDocumentBackend)
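
A minimal sketch of opting back into the old single-threaded behaviour by passing the renamed class explicitly, assuming docling's existing DocumentConverter / PdfFormatOption wiring (pipeline_cls, pipeline_options), which is not part of this diff; the module path for the new file is assumed and may differ.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
# hypothetical module path for the new legacy pipeline file shown above
from docling.pipeline.legacy_standard_pdf_pipeline import LegacyStandardPdfPipeline

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=LegacyStandardPdfPipeline,
            pipeline_options=PdfPipelineOptions(do_table_structure=True),
        )
    }
)
result = converter.convert("example.pdf")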

View File

@@ -1,19 +1,39 @@
"""Thread-safe, production-ready PDF pipeline
================================================
A self-contained, thread-safe PDF conversion pipeline exploiting parallelism between pipeline stages and models.
* **Per-run isolation** - every :py:meth:`execute` call uses its own bounded queues and worker
threads so that concurrent invocations never share mutable state.
* **Deterministic run identifiers** - pages are tracked with an internal *run-id* instead of
relying on :pyfunc:`id`, which may clash after garbage collection.
* **Explicit back-pressure & shutdown** - producers block on full queues; queue *close()*
propagates downstream so stages terminate deterministically without sentinels.
* **Minimal shared state** - heavyweight models are initialised once per pipeline instance
and only read by worker threads; no runtime mutability is exposed.
* **Strict typing & clean API usage** - code is fully annotated and respects *coding_rules.md*.
"""
from __future__ import annotations
import itertools
import logging
import threading
import time
import warnings
from collections import defaultdict, deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, cast
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, cast
import numpy as np
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import AssembledUnit, Page
from docling.datamodel.base_models import AssembledUnit, ConversionStatus, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.layout_model_specs import LayoutModelConfig
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.factories import get_ocr_factory
from docling.models.layout_model import LayoutModel
@@ -24,132 +44,588 @@ from docling.models.page_preprocessing_model import (
)
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.model_downloader import download_models
from docling.pipeline.base_pipeline import ConvertPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
_log = logging.getLogger(__name__)
# ──────────────────────────────────────────────────────────────────────────────
# Helper data structures
# ──────────────────────────────────────────────────────────────────────────────
class StandardPdfPipeline(PaginatedPipeline):
def __init__(self, pipeline_options: PdfPipelineOptions):
@dataclass
class ThreadedItem:
"""Envelope that travels between pipeline stages."""
payload: Optional[Page]
run_id: int # Unique per *execute* call, monotonic across pipeline instance
page_no: int
conv_res: ConversionResult
error: Optional[Exception] = None
is_failed: bool = False
@dataclass
class ProcessingResult:
"""Aggregated outcome of a pipeline run."""
pages: List[Page] = field(default_factory=list)
failed_pages: List[Tuple[int, Exception]] = field(default_factory=list)
total_expected: int = 0
@property
def success_count(self) -> int:
return len(self.pages)
@property
def failure_count(self) -> int:
return len(self.failed_pages)
@property
def is_partial_success(self) -> bool:
return 0 < self.success_count < self.total_expected
@property
def is_complete_failure(self) -> bool:
return self.success_count == 0 and self.failure_count > 0
class ThreadedQueue:
"""Bounded queue with blocking put/ get_batch and explicit *close()* semantics."""
__slots__ = ("_closed", "_items", "_lock", "_max", "_not_empty", "_not_full")
def __init__(self, max_size: int) -> None:
self._max: int = max_size
self._items: deque[ThreadedItem] = deque()
self._lock = threading.Lock()
self._not_full = threading.Condition(self._lock)
self._not_empty = threading.Condition(self._lock)
self._closed = False
# ---------------------------------------------------------------- put()
def put(self, item: ThreadedItem, timeout: Optional[float] | None = None) -> bool:
"""Block until queue accepts *item* or is closed. Returns *False* if closed."""
with self._not_full:
if self._closed:
return False
start = time.monotonic()
while len(self._items) >= self._max and not self._closed:
if timeout is not None:
remaining = timeout - (time.monotonic() - start)
if remaining <= 0:
return False
self._not_full.wait(remaining)
else:
self._not_full.wait()
if self._closed:
return False
self._items.append(item)
self._not_empty.notify()
return True
# ------------------------------------------------------------ get_batch()
def get_batch(
self, size: int, timeout: Optional[float] | None = None
) -> List[ThreadedItem]:
"""Return up to *size* items. Blocks until ≥1 item present or queue closed/timeout."""
with self._not_empty:
start = time.monotonic()
while not self._items and not self._closed:
if timeout is not None:
remaining = timeout - (time.monotonic() - start)
if remaining <= 0:
return []
self._not_empty.wait(remaining)
else:
self._not_empty.wait()
batch: List[ThreadedItem] = []
while self._items and len(batch) < size:
batch.append(self._items.popleft())
if batch:
self._not_full.notify_all()
return batch
# ---------------------------------------------------------------- close()
def close(self) -> None:
with self._lock:
self._closed = True
self._not_empty.notify_all()
self._not_full.notify_all()
# -------------------------------------------------------------- property
@property
def closed(self) -> bool:
return self._closed
class ThreadedPipelineStage:
"""A single pipeline stage backed by one worker thread."""
def __init__(
self,
*,
name: str,
model: Any,
batch_size: int,
batch_timeout: float,
queue_max_size: int,
postprocess: Optional[Callable[[ThreadedItem], None]] = None,
) -> None:
self.name = name
self.model = model
self.batch_size = batch_size
self.batch_timeout = batch_timeout
self.input_queue = ThreadedQueue(queue_max_size)
self._outputs: list[ThreadedQueue] = []
self._thread: Optional[threading.Thread] = None
self._running = False
self._postprocess = postprocess
# ---------------------------------------------------------------- wiring
def add_output_queue(self, q: ThreadedQueue) -> None:
self._outputs.append(q)
# -------------------------------------------------------------- lifecycle
def start(self) -> None:
if self._running:
return
self._running = True
self._thread = threading.Thread(
target=self._run, name=f"Stage-{self.name}", daemon=True
)
self._thread.start()
def stop(self) -> None:
if not self._running:
return
self._running = False
self.input_queue.close()
if self._thread is not None:
self._thread.join(timeout=30.0)
if self._thread.is_alive():
_log.warning("Stage %s did not terminate cleanly within 30s", self.name)
# ------------------------------------------------------------------ _run
def _run(self) -> None:
try:
while self._running:
batch = self.input_queue.get_batch(self.batch_size, self.batch_timeout)
if not batch and self.input_queue.closed:
break
processed = self._process_batch(batch)
self._emit(processed)
except Exception: # pragma: no cover - top-level guard
_log.exception("Fatal error in stage %s", self.name)
finally:
for q in self._outputs:
q.close()
# ----------------------------------------------------- _process_batch()
def _process_batch(self, batch: Sequence[ThreadedItem]) -> list[ThreadedItem]:
"""Run *model* on *batch* grouped by run_id to maximise batching."""
groups: dict[int, list[ThreadedItem]] = defaultdict(list)
for itm in batch:
groups[itm.run_id].append(itm)
result: list[ThreadedItem] = []
for rid, items in groups.items():
good: list[ThreadedItem] = [i for i in items if not i.is_failed]
if not good:
result.extend(items)
continue
try:
# Filter out None payloads and ensure type safety
pages_with_payloads = [
(i, i.payload) for i in good if i.payload is not None
]
if len(pages_with_payloads) != len(good):
# Some items have None payloads, mark all as failed
for it in items:
it.is_failed = True
it.error = RuntimeError("Page payload is None")
result.extend(items)
continue
pages: List[Page] = [payload for _, payload in pages_with_payloads]
processed_pages = list(self.model(good[0].conv_res, pages)) # type: ignore[arg-type]
if len(processed_pages) != len(pages): # strict mismatch guard
raise RuntimeError(
f"Model {self.name} returned wrong number of pages"
)
for idx, page in enumerate(processed_pages):
result.append(
ThreadedItem(
payload=page,
run_id=rid,
page_no=good[idx].page_no,
conv_res=good[idx].conv_res,
)
)
except Exception as exc:
_log.error("Stage %s failed for run %d: %s", self.name, rid, exc)
for it in items:
it.is_failed = True
it.error = exc
result.extend(items)
return result
# -------------------------------------------------------------- _emit()
def _emit(self, items: Iterable[ThreadedItem]) -> None:
for item in items:
if self._postprocess is not None:
self._postprocess(item)
for q in self._outputs:
if not q.put(item):
_log.error("Output queue closed while emitting from %s", self.name)
class PreprocessThreadedStage(ThreadedPipelineStage):
"""Pipeline stage that lazily loads PDF backends just-in-time."""
def __init__(
self,
*,
batch_timeout: float,
queue_max_size: int,
model: Any,
) -> None:
super().__init__(
name="preprocess",
model=model,
batch_size=1,
batch_timeout=batch_timeout,
queue_max_size=queue_max_size,
)
def _process_batch(self, batch: Sequence[ThreadedItem]) -> list[ThreadedItem]:
groups: dict[int, list[ThreadedItem]] = defaultdict(list)
for itm in batch:
groups[itm.run_id].append(itm)
result: list[ThreadedItem] = []
for rid, items in groups.items():
good = [i for i in items if not i.is_failed]
if not good:
result.extend(items)
continue
try:
pages_with_payloads: list[tuple[ThreadedItem, Page]] = []
for it in good:
page = it.payload
if page is None:
raise RuntimeError("Page payload is None")
if page._backend is None:
backend = it.conv_res.input._backend
assert isinstance(backend, PdfDocumentBackend), (
"Threaded pipeline only supports PdfDocumentBackend."
)
page_backend = backend.load_page(page.page_no)
page._backend = page_backend
if page_backend.is_valid():
page.size = page_backend.get_size()
pages_with_payloads.append((it, page))
pages = [payload for _, payload in pages_with_payloads]
processed_pages = list(
self.model(good[0].conv_res, pages) # type: ignore[arg-type]
)
if len(processed_pages) != len(pages):
raise RuntimeError(
"PagePreprocessingModel returned unexpected number of pages"
)
for idx, processed_page in enumerate(processed_pages):
result.append(
ThreadedItem(
payload=processed_page,
run_id=rid,
page_no=good[idx].page_no,
conv_res=good[idx].conv_res,
)
)
except Exception as exc:
_log.error("Stage preprocess failed for run %d: %s", rid, exc)
for it in items:
it.is_failed = True
it.error = exc
result.extend(items)
return result
@dataclass
class RunContext:
"""Wiring for a single *execute* call."""
stages: list[ThreadedPipelineStage]
first_stage: ThreadedPipelineStage
output_queue: ThreadedQueue
# ──────────────────────────────────────────────────────────────────────────────
# Main pipeline
# ──────────────────────────────────────────────────────────────────────────────
class StandardPdfPipeline(ConvertPipeline):
"""High-performance PDF pipeline with multi-threaded stages."""
def __init__(self, pipeline_options: ThreadedPdfPipelineOptions) -> None:
super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions
self.pipeline_options: ThreadedPdfPipelineOptions = pipeline_options
self._run_seq = itertools.count(1) # deterministic, monotonic run ids
with warnings.catch_warnings(): # deprecated generate_table_images
warnings.filterwarnings("ignore", category=DeprecationWarning)
self.keep_images = (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
# initialise heavy models once
self._init_models()
# ────────────────────────────────────────────────────────────────────────
# Heavy-model initialisation & helpers
# ────────────────────────────────────────────────────────────────────────
def _init_models(self) -> None:
art_path = self.artifacts_path
self.keep_images = (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
)
self.preprocessing_model = PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=self.pipeline_options.images_scale
)
)
self.ocr_model = self._make_ocr_model(art_path)
self.layout_model = LayoutModel(
artifacts_path=art_path,
accelerator_options=self.pipeline_options.accelerator_options,
options=self.pipeline_options.layout_options,
)
self.table_model = TableStructureModel(
enabled=self.pipeline_options.do_table_structure,
artifacts_path=art_path,
options=self.pipeline_options.table_structure_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
ocr_model = self.get_ocr_model(artifacts_path=self.artifacts_path)
self.build_pipe = [
# Pre-processing
PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=pipeline_options.images_scale,
)
),
# OCR
ocr_model,
# Layout model
LayoutModel(
artifacts_path=self.artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
options=pipeline_options.layout_options,
),
# Table structure model
TableStructureModel(
enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path,
options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options,
),
# Page assemble
PageAssembleModel(options=PageAssembleOptions()),
]
# --- optional enrichment ------------------------------------------------
self.enrichment_pipe = [
# Code Formula Enrichment Model
CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment,
enabled=self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_formula_enrichment,
artifacts_path=self.artifacts_path,
options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment,
do_code_enrichment=self.pipeline_options.do_code_enrichment,
do_formula_enrichment=self.pipeline_options.do_formula_enrichment,
),
accelerator_options=pipeline_options.accelerator_options,
accelerator_options=self.pipeline_options.accelerator_options,
),
*self.enrichment_pipe,
]
if (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_picture_classification
or self.pipeline_options.do_picture_description
):
self.keep_backend = True
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
warnings.warn(
"The usage of StandardPdfPipeline.download_models_hf() is deprecated "
"use instead the utility `docling-tools models download`, or "
"the upstream method docling.utils.models_downloader.download_all()",
DeprecationWarning,
stacklevel=3,
self.keep_backend = any(
(
self.pipeline_options.do_formula_enrichment,
self.pipeline_options.do_code_enrichment,
self.pipeline_options.do_picture_classification,
self.pipeline_options.do_picture_description,
)
)
output_dir = download_models(output_dir=local_dir, force=force, progress=False)
return output_dir
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
# ---------------------------------------------------------------- helpers
def _make_ocr_model(self, art_path: Optional[Path]) -> Any:
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr,
artifacts_path=artifacts_path,
artifacts_path=art_path,
accelerator_options=self.pipeline_options.accelerator_options,
)
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
def _release_page_resources(self, item: ThreadedItem) -> None:
page = item.payload
if page is None:
return
if not self.keep_images:
page._image_cache = {}
if not self.keep_backend and page._backend is not None:
page._backend.unload()
page._backend = None
if not self.pipeline_options.generate_parsed_pages:
page.parsed_page = None
return page
# ────────────────────────────────────────────────────────────────────────
# Build - thread pipeline
# ────────────────────────────────────────────────────────────────────────
def _create_run_ctx(self) -> RunContext:
opts = self.pipeline_options
preprocess = PreprocessThreadedStage(
batch_timeout=opts.batch_polling_interval_seconds,
queue_max_size=opts.queue_max_size,
model=self.preprocessing_model,
)
ocr = ThreadedPipelineStage(
name="ocr",
model=self.ocr_model,
batch_size=opts.ocr_batch_size,
batch_timeout=opts.batch_polling_interval_seconds,
queue_max_size=opts.queue_max_size,
)
layout = ThreadedPipelineStage(
name="layout",
model=self.layout_model,
batch_size=opts.layout_batch_size,
batch_timeout=opts.batch_polling_interval_seconds,
queue_max_size=opts.queue_max_size,
)
table = ThreadedPipelineStage(
name="table",
model=self.table_model,
batch_size=opts.table_batch_size,
batch_timeout=opts.batch_polling_interval_seconds,
queue_max_size=opts.queue_max_size,
)
assemble = ThreadedPipelineStage(
name="assemble",
model=self.assemble_model,
batch_size=1,
batch_timeout=opts.batch_polling_interval_seconds,
queue_max_size=opts.queue_max_size,
postprocess=self._release_page_resources,
)
# wire stages
output_q = ThreadedQueue(opts.queue_max_size)
preprocess.add_output_queue(ocr.input_queue)
ocr.add_output_queue(layout.input_queue)
layout.add_output_queue(table.input_queue)
table.add_output_queue(assemble.input_queue)
assemble.add_output_queue(output_q)
stages = [preprocess, ocr, layout, table, assemble]
return RunContext(stages=stages, first_stage=preprocess, output_queue=output_q)
# --------------------------------------------------------------------- build
def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
"""Stream-build the document while interleaving producer and consumer work."""
run_id = next(self._run_seq)
assert isinstance(conv_res.input._backend, PdfDocumentBackend)
# Collect page placeholders; backends are loaded lazily in preprocess stage
start_page, end_page = conv_res.input.limits.page_range
pages: list[Page] = []
for i in range(conv_res.input.page_count):
if start_page - 1 <= i <= end_page - 1:
page = Page(page_no=i)
conv_res.pages.append(page)
pages.append(page)
if not pages:
conv_res.status = ConversionStatus.FAILURE
return conv_res
total_pages: int = len(pages)
ctx: RunContext = self._create_run_ctx()
for st in ctx.stages:
st.start()
proc = ProcessingResult(total_expected=total_pages)
fed_idx: int = 0 # number of pages successfully queued
batch_size: int = 32 # drain chunk
try:
while proc.success_count + proc.failure_count < total_pages:
# 1) feed - try to enqueue until the first queue is full
while fed_idx < total_pages:
ok = ctx.first_stage.input_queue.put(
ThreadedItem(
payload=pages[fed_idx],
run_id=run_id,
page_no=pages[fed_idx].page_no,
conv_res=conv_res,
),
timeout=0.0, # non-blocking try-put
)
if ok:
fed_idx += 1
if fed_idx == total_pages:
ctx.first_stage.input_queue.close()
else: # queue full - switch to draining
break
# 2) drain - pull whatever is ready from the output side
out_batch = ctx.output_queue.get_batch(batch_size, timeout=0.05)
for itm in out_batch:
if itm.run_id != run_id:
continue
if itm.is_failed or itm.error:
proc.failed_pages.append(
(itm.page_no, itm.error or RuntimeError("unknown error"))
)
else:
assert itm.payload is not None
proc.pages.append(itm.payload)
# 3) failure safety - downstream closed early -> mark missing pages failed
if not out_batch and ctx.output_queue.closed:
missing = total_pages - (proc.success_count + proc.failure_count)
if missing > 0:
proc.failed_pages.extend(
[(-1, RuntimeError("pipeline terminated early"))] * missing
)
break
finally:
for st in ctx.stages:
st.stop()
ctx.output_queue.close()
self._integrate_results(conv_res, proc)
return conv_res
# ---------------------------------------------------- integrate_results()
def _integrate_results(
self, conv_res: ConversionResult, proc: ProcessingResult
) -> None:
page_map = {p.page_no: p for p in proc.pages}
conv_res.pages = [
page_map.get(p.page_no, p)
for p in conv_res.pages
if p.page_no in page_map
or not any(fp == p.page_no for fp, _ in proc.failed_pages)
]
if proc.is_complete_failure:
conv_res.status = ConversionStatus.FAILURE
elif proc.is_partial_success:
conv_res.status = ConversionStatus.PARTIAL_SUCCESS
else:
conv_res.status = ConversionStatus.SUCCESS
if not self.keep_images:
for p in conv_res.pages:
p._image_cache = {}
for p in conv_res.pages:
if not self.keep_backend and p._backend is not None:
p._backend.unload()
if not self.pipeline_options.generate_parsed_pages:
del p.parsed_page
p.parsed_page = None
# ---------------------------------------------------------------- assemble
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
all_elements = []
all_headers = []
all_body = []
elements, headers, body = [], [], []
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
for p in conv_res.pages:
if p.assembled is not None:
for el in p.assembled.body:
all_body.append(el)
for el in p.assembled.headers:
all_headers.append(el)
for el in p.assembled.elements:
all_elements.append(el)
if p.assembled:
elements.extend(p.assembled.elements)
headers.extend(p.assembled.headers)
body.extend(p.assembled.body)
conv_res.assembled = AssembledUnit(
elements=all_elements, headers=all_headers, body=all_body
elements=elements, headers=headers, body=body
)
conv_res.document = self.reading_order_model(conv_res)
# Generate page images in the output
@@ -233,10 +709,21 @@ class StandardPdfPipeline(PaginatedPipeline):
return conv_res
# ---------------------------------------------------------------- misc
@classmethod
def get_default_options(cls) -> PdfPipelineOptions:
return PdfPipelineOptions()
def get_default_options(cls) -> ThreadedPdfPipelineOptions:
return ThreadedPdfPipelineOptions()
@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend):
def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
return isinstance(backend, PdfDocumentBackend)
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
return conv_res.status
def _unload(self, conv_res: ConversionResult) -> None:
for p in conv_res.pages:
if p._backend is not None:
p._backend.unload()
if conv_res.input._backend:
conv_res.input._backend.unload()
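
The feed/drain loop in _build_document relies on the ThreadedQueue semantics defined above: producers block when a queue is full, consumers pull in batches, and close() unblocks both sides so stages drain deterministically. A stripped-down, self-contained sketch of that same protocol (not the docling class itself, just an illustration of the idea):

import threading
from collections import deque
from typing import Any, List


class MiniQueue:
    """Bounded queue: blocking put, batched get, close() unblocks both sides."""

    def __init__(self, max_size: int) -> None:
        self._items: deque = deque()
        self._max = max_size
        self._lock = threading.Lock()
        self._not_full = threading.Condition(self._lock)
        self._not_empty = threading.Condition(self._lock)
        self._closed = False

    def put(self, item: Any) -> bool:
        with self._not_full:
            while len(self._items) >= self._max and not self._closed:
                self._not_full.wait()  # backpressure: block the producer
            if self._closed:
                return False
            self._items.append(item)
            self._not_empty.notify()
            return True

    def get_batch(self, size: int) -> List[Any]:
        with self._not_empty:
            while not self._items and not self._closed:
                self._not_empty.wait()  # wait for items or close()
            batch = [self._items.popleft() for _ in range(min(size, len(self._items)))]
            if batch:
                self._not_full.notify_all()
            return batch

    def close(self) -> None:
        with self._lock:
            self._closed = True
            self._not_empty.notify_all()
            self._not_full.notify_all()


q = MiniQueue(max_size=4)
results: List[int] = []


def consume() -> None:
    while True:
        batch = q.get_batch(3)
        if not batch:  # closed and fully drained
            break
        results.extend(batch)


consumer = threading.Thread(target=consume)
consumer.start()
for i in range(10):
    q.put(i)   # blocks whenever the queue holds 4 unconsumed items
q.close()      # lets the consumer's final get_batch() return [] instead of blocking
consumer.join()
assert sorted(results) == list(range(10))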

View File

@@ -1,647 +1,5 @@
# threaded_standard_pdf_pipeline.py
"""Thread-safe, production-ready PDF pipeline
================================================
A self-contained, thread-safe PDF conversion pipeline exploiting parallelism between pipeline stages and models.
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
* **Per-run isolation** - every :py:meth:`execute` call uses its own bounded queues and worker
threads so that concurrent invocations never share mutable state.
* **Deterministic run identifiers** - pages are tracked with an internal *run-id* instead of
relying on :pyfunc:`id`, which may clash after garbage collection.
* **Explicit back-pressure & shutdown** - producers block on full queues; queue *close()*
propagates downstream so stages terminate deterministically without sentinels.
* **Minimal shared state** - heavyweight models are initialised once per pipeline instance
and only read by worker threads; no runtime mutability is exposed.
* **Strict typing & clean API usage** - code is fully annotated and respects *coding_rules.md*.
"""
from __future__ import annotations
import itertools
import logging
import threading
import time
import warnings
from collections import defaultdict, deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence, Tuple, cast
import numpy as np
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import AssembledUnit, ConversionStatus, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions
from docling.datamodel.settings import settings
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.factories import get_ocr_factory
from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.base_pipeline import ConvertPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
_log = logging.getLogger(__name__)
# ──────────────────────────────────────────────────────────────────────────────
# Helper data structures
# ──────────────────────────────────────────────────────────────────────────────
@dataclass
class ThreadedItem:
"""Envelope that travels between pipeline stages."""
payload: Optional[Page]
run_id: int # Unique per *execute* call, monotonic across pipeline instance
page_no: int
conv_res: ConversionResult
error: Optional[Exception] = None
is_failed: bool = False
@dataclass
class ProcessingResult:
"""Aggregated outcome of a pipeline run."""
pages: List[Page] = field(default_factory=list)
failed_pages: List[Tuple[int, Exception]] = field(default_factory=list)
total_expected: int = 0
@property
def success_count(self) -> int:
return len(self.pages)
@property
def failure_count(self) -> int:
return len(self.failed_pages)
@property
def is_partial_success(self) -> bool:
return 0 < self.success_count < self.total_expected
@property
def is_complete_failure(self) -> bool:
return self.success_count == 0 and self.failure_count > 0
class ThreadedQueue:
"""Bounded queue with blocking put/ get_batch and explicit *close()* semantics."""
__slots__ = ("_closed", "_items", "_lock", "_max", "_not_empty", "_not_full")
def __init__(self, max_size: int) -> None:
self._max: int = max_size
self._items: deque[ThreadedItem] = deque()
self._lock = threading.Lock()
self._not_full = threading.Condition(self._lock)
self._not_empty = threading.Condition(self._lock)
self._closed = False
# ---------------------------------------------------------------- put()
def put(self, item: ThreadedItem, timeout: Optional[float] | None = None) -> bool:
"""Block until queue accepts *item* or is closed. Returns *False* if closed."""
with self._not_full:
if self._closed:
return False
start = time.monotonic()
while len(self._items) >= self._max and not self._closed:
if timeout is not None:
remaining = timeout - (time.monotonic() - start)
if remaining <= 0:
return False
self._not_full.wait(remaining)
else:
self._not_full.wait()
if self._closed:
return False
self._items.append(item)
self._not_empty.notify()
return True
# ------------------------------------------------------------ get_batch()
def get_batch(
self, size: int, timeout: Optional[float] | None = None
) -> List[ThreadedItem]:
"""Return up to *size* items. Blocks until ≥1 item present or queue closed/timeout."""
with self._not_empty:
start = time.monotonic()
while not self._items and not self._closed:
if timeout is not None:
remaining = timeout - (time.monotonic() - start)
if remaining <= 0:
return []
self._not_empty.wait(remaining)
else:
self._not_empty.wait()
batch: List[ThreadedItem] = []
while self._items and len(batch) < size:
batch.append(self._items.popleft())
if batch:
self._not_full.notify_all()
return batch
# ---------------------------------------------------------------- close()
def close(self) -> None:
with self._lock:
self._closed = True
self._not_empty.notify_all()
self._not_full.notify_all()
# -------------------------------------------------------------- property
@property
def closed(self) -> bool:
return self._closed
class ThreadedPipelineStage:
"""A single pipeline stage backed by one worker thread."""
def __init__(
self,
*,
name: str,
model: Any,
batch_size: int,
batch_timeout: float,
queue_max_size: int,
) -> None:
self.name = name
self.model = model
self.batch_size = batch_size
self.batch_timeout = batch_timeout
self.input_queue = ThreadedQueue(queue_max_size)
self._outputs: list[ThreadedQueue] = []
self._thread: Optional[threading.Thread] = None
self._running = False
# ---------------------------------------------------------------- wiring
def add_output_queue(self, q: ThreadedQueue) -> None:
self._outputs.append(q)
# -------------------------------------------------------------- lifecycle
def start(self) -> None:
if self._running:
return
self._running = True
self._thread = threading.Thread(
target=self._run, name=f"Stage-{self.name}", daemon=True
)
self._thread.start()
def stop(self) -> None:
if not self._running:
return
self._running = False
self.input_queue.close()
if self._thread is not None:
self._thread.join(timeout=30.0)
if self._thread.is_alive():
_log.warning("Stage %s did not terminate cleanly within 30s", self.name)
# ------------------------------------------------------------------ _run
def _run(self) -> None:
try:
while self._running:
batch = self.input_queue.get_batch(self.batch_size, self.batch_timeout)
if not batch and self.input_queue.closed:
break
processed = self._process_batch(batch)
self._emit(processed)
except Exception: # pragma: no cover - top-level guard
_log.exception("Fatal error in stage %s", self.name)
finally:
for q in self._outputs:
q.close()
# ----------------------------------------------------- _process_batch()
def _process_batch(self, batch: Sequence[ThreadedItem]) -> list[ThreadedItem]:
"""Run *model* on *batch* grouped by run_id to maximise batching."""
groups: dict[int, list[ThreadedItem]] = defaultdict(list)
for itm in batch:
groups[itm.run_id].append(itm)
result: list[ThreadedItem] = []
for rid, items in groups.items():
good: list[ThreadedItem] = [i for i in items if not i.is_failed]
if not good:
result.extend(items)
continue
try:
# Filter out None payloads and ensure type safety
pages_with_payloads = [
(i, i.payload) for i in good if i.payload is not None
]
if len(pages_with_payloads) != len(good):
# Some items have None payloads, mark all as failed
for it in items:
it.is_failed = True
it.error = RuntimeError("Page payload is None")
result.extend(items)
continue
pages: List[Page] = [payload for _, payload in pages_with_payloads]
processed_pages = list(self.model(good[0].conv_res, pages)) # type: ignore[arg-type]
if len(processed_pages) != len(pages): # strict mismatch guard
raise RuntimeError(
f"Model {self.name} returned wrong number of pages"
)
for idx, page in enumerate(processed_pages):
result.append(
ThreadedItem(
payload=page,
run_id=rid,
page_no=good[idx].page_no,
conv_res=good[idx].conv_res,
)
)
except Exception as exc:
_log.error("Stage %s failed for run %d: %s", self.name, rid, exc)
for it in items:
it.is_failed = True
it.error = exc
result.extend(items)
return result
# -------------------------------------------------------------- _emit()
def _emit(self, items: Iterable[ThreadedItem]) -> None:
for item in items:
for q in self._outputs:
if not q.put(item):
_log.error("Output queue closed while emitting from %s", self.name)
@dataclass
class RunContext:
"""Wiring for a single *execute* call."""
stages: list[ThreadedPipelineStage]
first_stage: ThreadedPipelineStage
output_queue: ThreadedQueue
# ──────────────────────────────────────────────────────────────────────────────
# Main pipeline
# ──────────────────────────────────────────────────────────────────────────────
class ThreadedStandardPdfPipeline(ConvertPipeline):
"""High-performance PDF pipeline with multi-threaded stages."""
def __init__(self, pipeline_options: ThreadedPdfPipelineOptions) -> None:
super().__init__(pipeline_options)
self.pipeline_options: ThreadedPdfPipelineOptions = pipeline_options
self._run_seq = itertools.count(1) # deterministic, monotonic run ids
# initialise heavy models once
self._init_models()
# ────────────────────────────────────────────────────────────────────────
# Heavy-model initialisation & helpers
# ────────────────────────────────────────────────────────────────────────
def _init_models(self) -> None:
art_path = self.artifacts_path
self.keep_images = (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
)
self.preprocessing_model = PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=self.pipeline_options.images_scale
)
)
self.ocr_model = self._make_ocr_model(art_path)
self.layout_model = LayoutModel(
artifacts_path=art_path,
accelerator_options=self.pipeline_options.accelerator_options,
options=self.pipeline_options.layout_options,
)
self.table_model = TableStructureModel(
enabled=self.pipeline_options.do_table_structure,
artifacts_path=art_path,
options=self.pipeline_options.table_structure_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
# --- optional enrichment ------------------------------------------------
self.enrichment_pipe = [
# Code Formula Enrichment Model
CodeFormulaModel(
enabled=self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_formula_enrichment,
artifacts_path=self.artifacts_path,
options=CodeFormulaModelOptions(
do_code_enrichment=self.pipeline_options.do_code_enrichment,
do_formula_enrichment=self.pipeline_options.do_formula_enrichment,
),
accelerator_options=self.pipeline_options.accelerator_options,
),
*self.enrichment_pipe,
]
self.keep_backend = any(
(
self.pipeline_options.do_formula_enrichment,
self.pipeline_options.do_code_enrichment,
self.pipeline_options.do_picture_classification,
self.pipeline_options.do_picture_description,
)
)
# ---------------------------------------------------------------- helpers
def _make_ocr_model(self, art_path: Optional[Path]) -> Any:
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr,
artifacts_path=art_path,
accelerator_options=self.pipeline_options.accelerator_options,
)
# ────────────────────────────────────────────────────────────────────────
# Build - thread pipeline
# ────────────────────────────────────────────────────────────────────────
def _create_run_ctx(self) -> RunContext:
opts = self.pipeline_options
preprocess = ThreadedPipelineStage(
name="preprocess",
model=self.preprocessing_model,
batch_size=1,
batch_timeout=opts.batch_timeout_seconds,
queue_max_size=opts.queue_max_size,
)
ocr = ThreadedPipelineStage(
name="ocr",
model=self.ocr_model,
batch_size=opts.ocr_batch_size,
batch_timeout=opts.batch_timeout_seconds,
queue_max_size=opts.queue_max_size,
)
layout = ThreadedPipelineStage(
name="layout",
model=self.layout_model,
batch_size=opts.layout_batch_size,
batch_timeout=opts.batch_timeout_seconds,
queue_max_size=opts.queue_max_size,
)
table = ThreadedPipelineStage(
name="table",
model=self.table_model,
batch_size=opts.table_batch_size,
batch_timeout=opts.batch_timeout_seconds,
queue_max_size=opts.queue_max_size,
)
assemble = ThreadedPipelineStage(
name="assemble",
model=self.assemble_model,
batch_size=1,
batch_timeout=opts.batch_timeout_seconds,
queue_max_size=opts.queue_max_size,
)
# wire stages
output_q = ThreadedQueue(opts.queue_max_size)
preprocess.add_output_queue(ocr.input_queue)
ocr.add_output_queue(layout.input_queue)
layout.add_output_queue(table.input_queue)
table.add_output_queue(assemble.input_queue)
assemble.add_output_queue(output_q)
stages = [preprocess, ocr, layout, table, assemble]
return RunContext(stages=stages, first_stage=preprocess, output_queue=output_q)
# --------------------------------------------------------------------- build
def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
"""Stream-build the document while interleaving producer and consumer work."""
run_id = next(self._run_seq)
assert isinstance(conv_res.input._backend, PdfDocumentBackend)
backend = conv_res.input._backend
# preload & initialise pages -------------------------------------------------------------
start_page, end_page = conv_res.input.limits.page_range
pages: list[Page] = []
for i in range(conv_res.input.page_count):
if start_page - 1 <= i <= end_page - 1:
page = Page(page_no=i)
page._backend = backend.load_page(i)
if page._backend and page._backend.is_valid():
page.size = page._backend.get_size()
conv_res.pages.append(page)
pages.append(page)
if not pages:
conv_res.status = ConversionStatus.FAILURE
return conv_res
total_pages: int = len(pages)
ctx: RunContext = self._create_run_ctx()
for st in ctx.stages:
st.start()
proc = ProcessingResult(total_expected=total_pages)
fed_idx: int = 0 # number of pages successfully queued
batch_size: int = 32 # drain chunk
try:
while proc.success_count + proc.failure_count < total_pages:
# 1) feed - try to enqueue until the first queue is full
while fed_idx < total_pages:
ok = ctx.first_stage.input_queue.put(
ThreadedItem(
payload=pages[fed_idx],
run_id=run_id,
page_no=pages[fed_idx].page_no,
conv_res=conv_res,
),
timeout=0.0, # non-blocking try-put
)
if ok:
fed_idx += 1
if fed_idx == total_pages:
ctx.first_stage.input_queue.close()
else: # queue full - switch to draining
break
# 2) drain - pull whatever is ready from the output side
out_batch = ctx.output_queue.get_batch(batch_size, timeout=0.05)
for itm in out_batch:
if itm.run_id != run_id:
continue
if itm.is_failed or itm.error:
proc.failed_pages.append(
(itm.page_no, itm.error or RuntimeError("unknown error"))
)
else:
assert itm.payload is not None
proc.pages.append(itm.payload)
# 3) failure safety - downstream closed early -> mark missing pages failed
if not out_batch and ctx.output_queue.closed:
missing = total_pages - (proc.success_count + proc.failure_count)
if missing > 0:
proc.failed_pages.extend(
[(-1, RuntimeError("pipeline terminated early"))] * missing
)
break
finally:
for st in ctx.stages:
st.stop()
ctx.output_queue.close()
self._integrate_results(conv_res, proc)
return conv_res
# ---------------------------------------------------- integrate_results()
def _integrate_results(
self, conv_res: ConversionResult, proc: ProcessingResult
) -> None:
page_map = {p.page_no: p for p in proc.pages}
conv_res.pages = [
page_map.get(p.page_no, p)
for p in conv_res.pages
if p.page_no in page_map
or not any(fp == p.page_no for fp, _ in proc.failed_pages)
]
if proc.is_complete_failure:
conv_res.status = ConversionStatus.FAILURE
elif proc.is_partial_success:
conv_res.status = ConversionStatus.PARTIAL_SUCCESS
else:
conv_res.status = ConversionStatus.SUCCESS
if not self.keep_images:
for p in conv_res.pages:
p._image_cache = {}
for p in conv_res.pages:
if not self.keep_backend and p._backend is not None:
p._backend.unload()
if not self.pipeline_options.generate_parsed_pages:
del p.parsed_page
p.parsed_page = None
# ---------------------------------------------------------------- assemble
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
elements, headers, body = [], [], []
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
for p in conv_res.pages:
if p.assembled:
elements.extend(p.assembled.elements)
headers.extend(p.assembled.headers)
body.extend(p.assembled.body)
conv_res.assembled = AssembledUnit(
elements=elements, headers=headers, body=body
)
conv_res.document = self.reading_order_model(conv_res)
# Generate page images in the output
if self.pipeline_options.generate_page_images:
for page in conv_res.pages:
assert page.image is not None
page_no = page.page_no + 1
conv_res.document.pages[page_no].image = ImageRef.from_pil(
page.image, dpi=int(72 * self.pipeline_options.images_scale)
)
# Generate images of the requested element types
with warnings.catch_warnings(): # deprecated generate_table_images
warnings.filterwarnings("ignore", category=DeprecationWarning)
if (
self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
):
scale = self.pipeline_options.images_scale
for element, _level in conv_res.document.iterate_items():
if not isinstance(element, DocItem) or len(element.prov) == 0:
continue
if (
isinstance(element, PictureItem)
and self.pipeline_options.generate_picture_images
) or (
isinstance(element, TableItem)
and self.pipeline_options.generate_table_images
):
page_ix = element.prov[0].page_no - 1
page = next(
(p for p in conv_res.pages if p.page_no == page_ix),
cast("Page", None),
)
assert page is not None
assert page.size is not None
assert page.image is not None
crop_bbox = (
element.prov[0]
.bbox.scaled(scale=scale)
.to_top_left_origin(
page_height=page.size.height * scale
)
)
cropped_im = page.image.crop(crop_bbox.as_tuple())
element.image = ImageRef.from_pil(
cropped_im, dpi=int(72 * scale)
)
# Aggregate confidence values for document:
if len(conv_res.pages) > 0:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="Mean of empty slice|All-NaN slice encountered",
)
conv_res.confidence.layout_score = float(
np.nanmean(
[c.layout_score for c in conv_res.confidence.pages.values()]
)
)
conv_res.confidence.parse_score = float(
np.nanquantile(
[c.parse_score for c in conv_res.confidence.pages.values()],
q=0.1, # parse score should relate to worst 10% of pages.
)
)
conv_res.confidence.table_score = float(
np.nanmean(
[c.table_score for c in conv_res.confidence.pages.values()]
)
)
conv_res.confidence.ocr_score = float(
np.nanmean(
[c.ocr_score for c in conv_res.confidence.pages.values()]
)
)
return conv_res
# ---------------------------------------------------------------- misc
@classmethod
def get_default_options(cls) -> ThreadedPdfPipelineOptions:
return ThreadedPdfPipelineOptions()
@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
return isinstance(backend, PdfDocumentBackend)
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
return conv_res.status
def _unload(self, conv_res: ConversionResult) -> None:
for p in conv_res.pages:
if p._backend is not None:
p._backend.unload()
if conv_res.input._backend:
conv_res.input._backend.unload()
class ThreadedStandardPdfPipeline(StandardPdfPipeline):
"""Backwards compatible import for ThreadedStandardPdfPipeline."""

View File

@@ -42,7 +42,7 @@ def test_threaded_pipeline_multiple_documents():
layout_batch_size=1,
table_batch_size=1,
ocr_batch_size=1,
batch_timeout_seconds=1.0,
batch_polling_interval_seconds=1.0,
do_table_structure=do_ts,
do_ocr=do_ocr,
),

uv.lock (generated)
View File

@@ -5943,6 +5943,9 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/8a/b35a615ae6f04550d696bb179c414538b3b477999435fdd4ad75b76139e4/pybase64-1.4.2-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:a370dea7b1cee2a36a4d5445d4e09cc243816c5bc8def61f602db5a6f5438e52", size = 54320, upload-time = "2025-07-27T13:03:27.495Z" },
{ url = "https://files.pythonhosted.org/packages/d3/a9/8bd4f9bcc53689f1b457ecefed1eaa080e4949d65a62c31a38b7253d5226/pybase64-1.4.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9aa4de83f02e462a6f4e066811c71d6af31b52d7484de635582d0e3ec3d6cc3e", size = 56482, upload-time = "2025-07-27T13:03:28.942Z" },
{ url = "https://files.pythonhosted.org/packages/75/e5/4a7735b54a1191f61c3f5c2952212c85c2d6b06eb5fb3671c7603395f70c/pybase64-1.4.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83a1c2f9ed00fee8f064d548c8654a480741131f280e5750bb32475b7ec8ee38", size = 70959, upload-time = "2025-07-27T13:03:30.171Z" },
{ url = "https://files.pythonhosted.org/packages/f4/56/5337f27a8b8d2d6693f46f7b36bae47895e5820bfa259b0072574a4e1057/pybase64-1.4.2-cp313-cp313-android_21_arm64_v8a.whl", hash = "sha256:0f331aa59549de21f690b6ccc79360ffed1155c3cfbc852eb5c097c0b8565a2b", size = 33888, upload-time = "2025-07-27T13:03:35.698Z" },
{ url = "https://files.pythonhosted.org/packages/e3/ff/470768f0fe6de0aa302a8cb1bdf2f9f5cffc3f69e60466153be68bc953aa/pybase64-1.4.2-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:69d3f0445b0faeef7bb7f93bf8c18d850785e2a77f12835f49e524cc54af04e7", size = 30914, upload-time = "2025-07-27T13:03:38.475Z" },
{ url = "https://files.pythonhosted.org/packages/75/6b/d328736662665e0892409dc410353ebef175b1be5eb6bab1dad579efa6df/pybase64-1.4.2-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:2372b257b1f4dd512f317fb27e77d313afd137334de64c87de8374027aacd88a", size = 31380, upload-time = "2025-07-27T13:03:39.7Z" },
{ url = "https://files.pythonhosted.org/packages/ca/96/7ff718f87c67f4147c181b73d0928897cefa17dc75d7abc6e37730d5908f/pybase64-1.4.2-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:fb794502b4b1ec91c4ca5d283ae71aef65e3de7721057bd9e2b3ec79f7a62d7d", size = 38230, upload-time = "2025-07-27T13:03:41.637Z" },
{ url = "https://files.pythonhosted.org/packages/71/ab/db4dbdfccb9ca874d6ce34a0784761471885d96730de85cee3d300381529/pybase64-1.4.2-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d377d48acf53abf4b926c2a7a24a19deb092f366a04ffd856bf4b3aa330b025d", size = 71608, upload-time = "2025-07-27T13:03:47.01Z" },
{ url = "https://files.pythonhosted.org/packages/f2/58/7f2cef1ceccc682088958448d56727369de83fa6b29148478f4d2acd107a/pybase64-1.4.2-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:ab9cdb6a8176a5cb967f53e6ad60e40c83caaa1ae31c5e1b29e5c8f507f17538", size = 56413, upload-time = "2025-07-27T13:03:49.908Z" },
@@ -5964,6 +5967,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/95/f0/c392c4ac8ccb7a34b28377c21faa2395313e3c676d76c382642e19a20703/pybase64-1.4.2-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:ad59362fc267bf15498a318c9e076686e4beeb0dfe09b457fabbc2b32468b97a", size = 58103, upload-time = "2025-07-27T13:04:29.996Z" },
{ url = "https://files.pythonhosted.org/packages/32/30/00ab21316e7df8f526aa3e3dc06f74de6711d51c65b020575d0105a025b2/pybase64-1.4.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:01593bd064e7dcd6c86d04e94e44acfe364049500c20ac68ca1e708fbb2ca970", size = 60779, upload-time = "2025-07-27T13:04:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a6/65/114ca81839b1805ce4a2b7d58bc16e95634734a2059991f6382fc71caf3e/pybase64-1.4.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5b81547ad8ea271c79fdf10da89a1e9313cb15edcba2a17adf8871735e9c02a0", size = 74684, upload-time = "2025-07-27T13:04:32.976Z" },
{ url = "https://files.pythonhosted.org/packages/99/bf/00a87d951473ce96c8c08af22b6983e681bfabdb78dd2dcf7ee58eac0932/pybase64-1.4.2-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:4157ad277a32cf4f02a975dffc62a3c67d73dfa4609b2c1978ef47e722b18b8e", size = 30924, upload-time = "2025-07-27T13:04:39.189Z" },
{ url = "https://files.pythonhosted.org/packages/ae/43/dee58c9d60e60e6fb32dc6da722d84592e22f13c277297eb4ce6baf99a99/pybase64-1.4.2-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:e113267dc349cf624eb4f4fbf53fd77835e1aa048ac6877399af426aab435757", size = 31390, upload-time = "2025-07-27T13:04:40.995Z" },
{ url = "https://files.pythonhosted.org/packages/e1/11/b28906fc2e330b8b1ab4bc845a7bef808b8506734e90ed79c6062b095112/pybase64-1.4.2-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:cea5aaf218fd9c5c23afacfe86fd4464dfedc1a0316dd3b5b4075b068cc67df0", size = 38212, upload-time = "2025-07-27T13:04:42.729Z" },
{ url = "https://files.pythonhosted.org/packages/e4/2e/851eb51284b97354ee5dfa1309624ab90920696e91a33cd85b13d20cc5c1/pybase64-1.4.2-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a3e54dcf0d0305ec88473c9d0009f698cabf86f88a8a10090efeff2879c421bb", size = 71674, upload-time = "2025-07-27T13:04:49.294Z" },
{ url = "https://files.pythonhosted.org/packages/a4/8e/3479266bc0e65f6cc48b3938d4a83bff045330649869d950a378f2ddece0/pybase64-1.4.2-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:753da25d4fd20be7bda2746f545935773beea12d5cb5ec56ec2d2960796477b1", size = 56461, upload-time = "2025-07-27T13:04:52.37Z" },