Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-25 19:44:34 +00:00)

Better threaded PDF pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

This commit is contained in:
parent f98c7e21dd
commit 8c905f3e70
@@ -334,39 +334,29 @@ class ProcessingPipeline(str, Enum):
    ASR = "asr"


class AsyncPdfPipelineOptions(PdfPipelineOptions):
    """Enhanced options for async pipeline with cross-document batching"""

class ThreadedPdfPipelineOptions(PdfPipelineOptions):
    """Pipeline options for the threaded PDF pipeline with batching and backpressure control"""

    # GPU batching configuration - larger than sync defaults
    layout_batch_size: int = 64
    ocr_batch_size: int = 32
    table_batch_size: int = 16
    # Batch sizes for different stages
    ocr_batch_size: int = 4
    layout_batch_size: int = 4
    table_batch_size: int = 4

    # Async coordination
    # Timing control
    batch_timeout_seconds: float = 2.0
    max_concurrent_extractions: int = 16

    # Queue sizes for backpressure
    extraction_queue_size: int = 100
    model_queue_size_multiplier: float = 2.0  # queue_size = batch_size * multiplier
    # Backpressure and queue control
    queue_max_size: int = 100
    max_workers: Optional[int] = None  # None uses ThreadPoolExecutor default

    # Resource management
    max_gpu_memory_mb: Optional[int] = None
    enable_resource_monitoring: bool = True

    # Safety settings
    enable_exception_isolation: bool = True
    cleanup_validation: bool = True
    # Pipeline coordination
    stage_timeout_seconds: float = 10.0  # Timeout for feeding items to stages
    collection_timeout_seconds: float = 5.0  # Timeout for collecting results

    @classmethod
    def from_sync_options(
        cls, sync_options: PdfPipelineOptions
    ) -> "AsyncPdfPipelineOptions":
        """Convert sync options to async options"""
        # Start with sync options and override with async defaults
    ) -> "ThreadedPdfPipelineOptions":
        """Convert sync options to threaded options"""
        data = sync_options.model_dump()

        # Remove sync-specific fields if any
        data.pop("page_batch_size", None)  # We don't use fixed page batching

        return cls(**data)
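A minimal usage sketch of the new options class, assuming only the fields shown in the hunk above (the example values are arbitrary):

from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    ThreadedPdfPipelineOptions,
)

# Start from ordinary sync options, e.g. an existing configuration.
sync_opts = PdfPipelineOptions(do_table_structure=True)

# from_sync_options() dumps the shared fields via model_dump() and
# re-validates them as ThreadedPdfPipelineOptions, so the threaded
# defaults shown above apply.
threaded_opts = ThreadedPdfPipelineOptions.from_sync_options(sync_opts)

# The batching and backpressure knobs can then be tuned per deployment.
threaded_opts.layout_batch_size = 8
threaded_opts.batch_timeout_seconds = 1.0
threaded_opts.queue_max_size = 50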
@@ -1,9 +1,8 @@
import asyncio
import hashlib
import logging
import sys
import time
from collections.abc import AsyncIterable, Iterable, Iterator
from collections.abc import Iterable, Iterator
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union
@@ -218,29 +217,29 @@ class DocumentConverter:
    @validate_call(config=ConfigDict(strict=True))
    def convert(
        self,
        source: Union[Path, str, DocumentStream],
        source: Union[Path, str, DocumentStream],  # TODO review naming
        headers: Optional[Dict[str, str]] = None,
        raises_on_error: bool = True,
        max_num_pages: int = sys.maxsize,
        max_file_size: int = sys.maxsize,
        page_range: PageRange = DEFAULT_PAGE_RANGE,
    ) -> ConversionResult:
        for result in self.convert_all(
        all_res = self.convert_all(
            source=[source],
            headers=headers,
            raises_on_error=raises_on_error,
            max_num_pages=max_num_pages,
            max_file_size=max_file_size,
            headers=headers,
            page_range=page_range,
        ):
            return result
        )
        return next(all_res)

    @validate_call(config=ConfigDict(strict=True))
    def convert_all(
        self,
        source: Iterable[Union[Path, str, DocumentStream]],
        source: Iterable[Union[Path, str, DocumentStream]],  # TODO review naming
        headers: Optional[Dict[str, str]] = None,
        raises_on_error: bool = True,
        raises_on_error: bool = True,  # True: raises on first conversion error; False: does not raise on conv error
        max_num_pages: int = sys.maxsize,
        max_file_size: int = sys.maxsize,
        page_range: PageRange = DEFAULT_PAGE_RANGE,
@@ -251,10 +250,7 @@ class DocumentConverter:
            page_range=page_range,
        )
        conv_input = _DocumentConversionInput(
            path_or_stream_iterator=source,
            allowed_formats=self.allowed_formats,
            limits=limits,
            headers=headers,
            path_or_stream_iterator=source, limits=limits, headers=headers
        )
        conv_res_iter = self._convert(conv_input, raises_on_error=raises_on_error)
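A short usage sketch of the reworked convert()/convert_all() entry points (the file names here are hypothetical):

from docling.document_converter import DocumentConverter

converter = DocumentConverter()

# Single document: convert() now delegates to convert_all([source]) and
# returns next(all_res), i.e. the single ConversionResult.
result = converter.convert("report.pdf")

# Multiple documents: convert_all() still yields results lazily, one per input.
for res in converter.convert_all(["a.pdf", "b.pdf"], raises_on_error=False):
    print(res.input.file.name, res.status)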
@@ -276,107 +272,6 @@ class DocumentConverter:
|
||||
"Conversion failed because the provided file has no recognizable format or it wasn't in the list of allowed formats."
|
||||
)
|
||||
|
||||
async def convert_all_async(
|
||||
self,
|
||||
source: Iterable[Union[Path, str, DocumentStream]],
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
raises_on_error: bool = True,
|
||||
max_num_pages: int = sys.maxsize,
|
||||
max_file_size: int = sys.maxsize,
|
||||
page_range: PageRange = DEFAULT_PAGE_RANGE,
|
||||
) -> AsyncIterable[ConversionResult]:
|
||||
"""
|
||||
Async version of convert_all with cross-document batching.
|
||||
|
||||
Yields results as they complete, not necessarily in input order.
|
||||
"""
|
||||
limits = DocumentLimits(
|
||||
max_num_pages=max_num_pages,
|
||||
max_file_size=max_file_size,
|
||||
page_range=page_range,
|
||||
)
|
||||
conv_input = _DocumentConversionInput(
|
||||
path_or_stream_iterator=source, limits=limits, headers=headers
|
||||
)
|
||||
|
||||
# Create async document stream
|
||||
async def doc_stream():
|
||||
for doc in conv_input.docs(self.format_to_options):
|
||||
yield doc
|
||||
|
||||
# Check if we have async-capable pipelines
|
||||
has_async = False
|
||||
for format_opt in self.format_to_options.values():
|
||||
if hasattr(format_opt.pipeline_cls, "execute_stream"):
|
||||
has_async = True
|
||||
break
|
||||
|
||||
if has_async:
|
||||
# Use async pipeline for cross-document batching
|
||||
# For now, assume PDF pipeline handles all async processing
|
||||
pdf_format_opt = self.format_to_options.get(InputFormat.PDF)
|
||||
|
||||
if pdf_format_opt is None:
|
||||
return
|
||||
|
||||
pipeline_cls = pdf_format_opt.pipeline_cls
|
||||
if hasattr(pipeline_cls, "execute_stream"):
|
||||
# Initialize async pipeline
|
||||
pipeline_options = self.format_to_options[
|
||||
InputFormat.PDF
|
||||
].pipeline_options
|
||||
|
||||
# Convert to async options if needed
|
||||
from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions
|
||||
|
||||
if not isinstance(pipeline_options, AsyncPdfPipelineOptions):
|
||||
pipeline_options = AsyncPdfPipelineOptions.from_sync_options(
|
||||
pipeline_options
|
||||
)
|
||||
|
||||
pipeline = pipeline_cls(pipeline_options)
|
||||
|
||||
# Process all documents through async pipeline
|
||||
async for result in pipeline.execute_stream(doc_stream()):
|
||||
yield result
|
||||
else:
|
||||
# Fallback to sequential async processing
|
||||
async for doc in doc_stream():
|
||||
result = await asyncio.to_thread(
|
||||
self._process_document, doc, raises_on_error
|
||||
)
|
||||
yield result
|
||||
else:
|
||||
# All pipelines are sync, process sequentially with threading
|
||||
async for doc in doc_stream():
|
||||
result = await asyncio.to_thread(
|
||||
self._process_document, doc, raises_on_error
|
||||
)
|
||||
yield result
|
||||
|
||||
async def convert_async(
|
||||
self,
|
||||
source: Union[Path, str, DocumentStream],
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
raises_on_error: bool = True,
|
||||
max_num_pages: int = sys.maxsize,
|
||||
max_file_size: int = sys.maxsize,
|
||||
page_range: PageRange = DEFAULT_PAGE_RANGE,
|
||||
) -> ConversionResult:
|
||||
"""Async convenience method for single document conversion."""
|
||||
async for result in self.convert_all_async(
|
||||
[source],
|
||||
headers=headers,
|
||||
raises_on_error=raises_on_error,
|
||||
max_num_pages=max_num_pages,
|
||||
max_file_size=max_file_size,
|
||||
page_range=page_range,
|
||||
):
|
||||
return result
|
||||
|
||||
# If no results were yielded, raise an error
|
||||
raise RuntimeError(f"No conversion result produced for source: {source}")
|
||||
|
||||
def _convert(
|
||||
self, conv_input: _DocumentConversionInput, raises_on_error: bool
|
||||
) -> Iterator[ConversionResult]:
|
||||
|
@@ -148,72 +148,90 @@ class LayoutModel(BasePageModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
# Convert to list to allow multiple iterations
|
||||
pages = list(page_batch)
|
||||
|
||||
# Separate valid and invalid pages
|
||||
valid_pages = []
|
||||
valid_page_images = []
|
||||
|
||||
for page in pages:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
continue
|
||||
|
||||
assert page.size is not None
|
||||
page_image = page.get_image(scale=1.0)
|
||||
assert page_image is not None
|
||||
|
||||
valid_pages.append(page)
|
||||
valid_page_images.append(page_image)
|
||||
|
||||
# Process all valid pages with batch prediction
|
||||
batch_predictions = []
|
||||
if valid_page_images:
|
||||
with TimeRecorder(conv_res, "layout"):
|
||||
batch_predictions = self.layout_predictor.predict_batch( # type: ignore[attr-defined]
|
||||
valid_page_images
|
||||
)
|
||||
|
||||
# Process each page with its predictions
|
||||
valid_page_idx = 0
|
||||
for page in pages:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
with TimeRecorder(conv_res, "layout"):
|
||||
assert page.size is not None
|
||||
page_image = page.get_image(scale=1.0)
|
||||
assert page_image is not None
|
||||
continue
|
||||
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(
|
||||
self.layout_predictor.predict(page_image)
|
||||
):
|
||||
label = DocItemLabel(
|
||||
pred_item["label"]
|
||||
.lower()
|
||||
.replace(" ", "_")
|
||||
.replace("-", "_")
|
||||
) # Temporary, until docling-ibm-model uses docling-core types
|
||||
cluster = Cluster(
|
||||
id=ix,
|
||||
label=label,
|
||||
confidence=pred_item["confidence"],
|
||||
bbox=BoundingBox.model_validate(pred_item),
|
||||
cells=[],
|
||||
)
|
||||
clusters.append(cluster)
|
||||
page_predictions = batch_predictions[valid_page_idx]
|
||||
valid_page_idx += 1
|
||||
|
||||
if settings.debug.visualize_raw_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, clusters, mode_prefix="raw"
|
||||
)
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(page_predictions):
|
||||
label = DocItemLabel(
|
||||
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
|
||||
) # Temporary, until docling-ibm-model uses docling-core types
|
||||
cluster = Cluster(
|
||||
id=ix,
|
||||
label=label,
|
||||
confidence=pred_item["confidence"],
|
||||
bbox=BoundingBox.model_validate(pred_item),
|
||||
cells=[],
|
||||
)
|
||||
clusters.append(cluster)
|
||||
|
||||
# Apply postprocessing
|
||||
if settings.debug.visualize_raw_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, clusters, mode_prefix="raw"
|
||||
)
|
||||
|
||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||
page, clusters, self.options
|
||||
).postprocess()
|
||||
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
|
||||
# Apply postprocessing
|
||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||
page, clusters, self.options
|
||||
).postprocess()
|
||||
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"Mean of empty slice|invalid value encountered in scalar divide",
|
||||
RuntimeWarning,
|
||||
"numpy",
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"Mean of empty slice|invalid value encountered in scalar divide",
|
||||
RuntimeWarning,
|
||||
"numpy",
|
||||
)
|
||||
|
||||
conv_res.confidence.pages[page.page_no].layout_score = float(
|
||||
np.mean([c.confidence for c in processed_clusters])
|
||||
)
|
||||
conv_res.confidence.pages[page.page_no].layout_score = float(
|
||||
np.mean([c.confidence for c in processed_clusters])
|
||||
)
|
||||
|
||||
conv_res.confidence.pages[page.page_no].ocr_score = float(
|
||||
np.mean(
|
||||
[c.confidence for c in processed_cells if c.from_ocr]
|
||||
)
|
||||
)
|
||||
conv_res.confidence.pages[page.page_no].ocr_score = float(
|
||||
np.mean([c.confidence for c in processed_cells if c.from_ocr])
|
||||
)
|
||||
|
||||
page.predictions.layout = LayoutPrediction(
|
||||
clusters=processed_clusters
|
||||
)
|
||||
page.predictions.layout = LayoutPrediction(clusters=processed_clusters)
|
||||
|
||||
if settings.debug.visualize_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||
)
|
||||
if settings.debug.visualize_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||
)
|
||||
|
||||
yield page
|
||||
yield page
|
||||
|
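A condensed sketch, not the exact model code, of the batching pattern this LayoutModel change introduces: collect images from valid pages first, call predict_batch() once, then pair each prediction list back with its page in order.

def run_layout_in_batches(pages, layout_predictor):
    # Keep pages whose backend is still valid, together with their images.
    valid = [
        (page, page.get_image(scale=1.0))
        for page in pages
        if page._backend is not None and page._backend.is_valid()
    ]

    # One batched predictor call replaces the previous per-page predict() loop.
    images = [image for _, image in valid]
    batch_predictions = layout_predictor.predict_batch(images) if images else []

    # Re-associate each prediction list with its page, preserving input order.
    for (page, _), predictions in zip(valid, batch_predictions):
        yield page, predictions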
@@ -1,54 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AsyncIterable, Dict, Optional, Set, Tuple
|
||||
|
||||
from docling.backend.pdf_backend import PdfPageBackend
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.document import ConversionResult, InputDocument
|
||||
from docling.datamodel.pipeline_options import PipelineOptions
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentTracker:
|
||||
"""Tracks document processing state for resource management"""
|
||||
|
||||
doc_id: str
|
||||
total_pages: int
|
||||
processed_pages: int = 0
|
||||
page_backends: Dict[int, PdfPageBackend] = field(
|
||||
default_factory=dict
|
||||
) # page_no -> backend
|
||||
conv_result: Optional[ConversionResult] = None
|
||||
|
||||
|
||||
class AsyncPipeline(ABC):
|
||||
"""Base class for async pipeline implementations"""
|
||||
|
||||
def __init__(self, pipeline_options: PipelineOptions):
|
||||
self.pipeline_options = pipeline_options
|
||||
self.keep_images = False
|
||||
self.keep_backend = False
|
||||
|
||||
@abstractmethod
|
||||
async def execute_stream(
|
||||
self, input_docs: AsyncIterable[InputDocument]
|
||||
) -> AsyncIterable[ConversionResult]:
|
||||
"""Process multiple documents with cross-document batching"""
|
||||
|
||||
async def execute_single(
|
||||
self, in_doc: InputDocument, raises_on_error: bool = True
|
||||
) -> ConversionResult:
|
||||
"""Process a single document - for backward compatibility"""
|
||||
|
||||
async def single_doc_stream():
|
||||
yield in_doc
|
||||
|
||||
async for result in self.execute_stream(single_doc_stream()):
|
||||
return result
|
||||
|
||||
# Should never reach here
|
||||
raise RuntimeError("No result produced for document")
|
@@ -1,433 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, AsyncIterable, Dict, List, Optional, Tuple
|
||||
|
||||
from docling.datamodel.base_models import ConversionStatus, Page
|
||||
from docling.datamodel.document import ConversionResult, InputDocument
|
||||
from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||
from docling.models.document_picture_classifier import (
|
||||
DocumentPictureClassifier,
|
||||
DocumentPictureClassifierOptions,
|
||||
)
|
||||
from docling.models.factories import get_ocr_factory, get_picture_description_factory
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
|
||||
from docling.models.page_preprocessing_model import (
|
||||
PagePreprocessingModel,
|
||||
PagePreprocessingOptions,
|
||||
)
|
||||
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.pipeline.async_base_pipeline import AsyncPipeline
|
||||
from docling.pipeline.graph import GraphRunner, get_pipeline_thread_pool
|
||||
from docling.pipeline.resource_manager import AsyncPageTracker
|
||||
from docling.pipeline.stages import (
|
||||
AggregationStage,
|
||||
BatchProcessorStage,
|
||||
ExtractionStage,
|
||||
PageProcessorStage,
|
||||
SinkStage,
|
||||
SourceStage,
|
||||
)
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
"""
|
||||
An async, graph-based pipeline for processing PDFs with cross-document batching.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline_options: AsyncPdfPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.pipeline_options: AsyncPdfPipelineOptions = pipeline_options
|
||||
self.page_tracker = AsyncPageTracker(
|
||||
keep_images=self._should_keep_images(),
|
||||
keep_backend=self._should_keep_backend(),
|
||||
)
|
||||
# Get shared thread pool for enrichment operations
|
||||
self._thread_pool = get_pipeline_thread_pool()
|
||||
self._initialize_models()
|
||||
|
||||
def _should_keep_images(self) -> bool:
|
||||
return (
|
||||
self.pipeline_options.generate_page_images
|
||||
or self.pipeline_options.generate_picture_images
|
||||
or self.pipeline_options.generate_table_images
|
||||
)
|
||||
|
||||
def _should_keep_backend(self) -> bool:
|
||||
return (
|
||||
self.pipeline_options.do_formula_enrichment
|
||||
or self.pipeline_options.do_code_enrichment
|
||||
or self.pipeline_options.do_picture_classification
|
||||
or self.pipeline_options.do_picture_description
|
||||
)
|
||||
|
||||
def _initialize_models(self):
|
||||
artifacts_path = self._get_artifacts_path()
|
||||
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
|
||||
self.preprocessing_model = PagePreprocessingModel(
|
||||
options=PagePreprocessingOptions(
|
||||
images_scale=self.pipeline_options.images_scale,
|
||||
)
|
||||
)
|
||||
self.ocr_model = self._get_ocr_model(artifacts_path)
|
||||
self.layout_model = LayoutModel(
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
options=self.pipeline_options.layout_options,
|
||||
)
|
||||
self.table_model = TableStructureModel(
|
||||
enabled=self.pipeline_options.do_table_structure,
|
||||
artifacts_path=artifacts_path,
|
||||
options=self.pipeline_options.table_structure_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
|
||||
self.code_formula_model = CodeFormulaModel(
|
||||
enabled=self.pipeline_options.do_code_enrichment
|
||||
or self.pipeline_options.do_formula_enrichment,
|
||||
artifacts_path=artifacts_path,
|
||||
options=CodeFormulaModelOptions(
|
||||
do_code_enrichment=self.pipeline_options.do_code_enrichment,
|
||||
do_formula_enrichment=self.pipeline_options.do_formula_enrichment,
|
||||
),
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
self.picture_classifier = DocumentPictureClassifier(
|
||||
enabled=self.pipeline_options.do_picture_classification,
|
||||
artifacts_path=artifacts_path,
|
||||
options=DocumentPictureClassifierOptions(),
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
self.picture_description_model = self._get_picture_description_model(
|
||||
artifacts_path
|
||||
)
|
||||
|
||||
def _get_artifacts_path(self) -> Optional[str]:
|
||||
from pathlib import Path
|
||||
|
||||
artifacts_path = None
|
||||
if self.pipeline_options.artifacts_path is not None:
|
||||
artifacts_path = Path(self.pipeline_options.artifacts_path).expanduser()
|
||||
elif settings.artifacts_path is not None:
|
||||
artifacts_path = Path(settings.artifacts_path).expanduser()
|
||||
|
||||
if artifacts_path is not None and not artifacts_path.is_dir():
|
||||
raise RuntimeError(
|
||||
f"The value of {artifacts_path=} is not valid. "
|
||||
"When defined, it must point to a folder containing all models required by the pipeline."
|
||||
)
|
||||
return artifacts_path
|
||||
|
||||
def _get_ocr_model(self, artifacts_path: Optional[str] = None) -> BaseOcrModel:
|
||||
factory = get_ocr_factory(
|
||||
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||
)
|
||||
return factory.create_instance(
|
||||
options=self.pipeline_options.ocr_options,
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
def _get_picture_description_model(self, artifacts_path: Optional[str] = None):
|
||||
factory = get_picture_description_factory(
|
||||
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||
)
|
||||
return factory.create_instance(
|
||||
options=self.pipeline_options.picture_description_options,
|
||||
enabled=self.pipeline_options.do_picture_description,
|
||||
enable_remote_services=self.pipeline_options.enable_remote_services,
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
async def execute_stream(
|
||||
self, input_docs: AsyncIterable[InputDocument]
|
||||
) -> AsyncIterable[ConversionResult]:
|
||||
"""Main async processing driven by a pipeline graph."""
|
||||
stages = [
|
||||
SourceStage("source"),
|
||||
ExtractionStage(
|
||||
"extractor",
|
||||
self.page_tracker,
|
||||
self.pipeline_options.max_concurrent_extractions,
|
||||
),
|
||||
PageProcessorStage("preprocessor", self.preprocessing_model),
|
||||
BatchProcessorStage(
|
||||
"ocr",
|
||||
self.ocr_model,
|
||||
self.pipeline_options.ocr_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
),
|
||||
BatchProcessorStage(
|
||||
"layout",
|
||||
self.layout_model,
|
||||
self.pipeline_options.layout_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
),
|
||||
BatchProcessorStage(
|
||||
"table",
|
||||
self.table_model,
|
||||
self.pipeline_options.table_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
),
|
||||
PageProcessorStage("assembler", self.assemble_model),
|
||||
AggregationStage("aggregator", self.page_tracker, self._finalize_document),
|
||||
SinkStage("sink"),
|
||||
]
|
||||
|
||||
edges = [
|
||||
# Main processing path
|
||||
{
|
||||
"from_stage": "source",
|
||||
"from_output": "out",
|
||||
"to_stage": "extractor",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "extractor",
|
||||
"from_output": "out",
|
||||
"to_stage": "preprocessor",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "preprocessor",
|
||||
"from_output": "out",
|
||||
"to_stage": "ocr",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "ocr",
|
||||
"from_output": "out",
|
||||
"to_stage": "layout",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "layout",
|
||||
"from_output": "out",
|
||||
"to_stage": "table",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "table",
|
||||
"from_output": "out",
|
||||
"to_stage": "assembler",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "assembler",
|
||||
"from_output": "out",
|
||||
"to_stage": "aggregator",
|
||||
"to_input": "in",
|
||||
},
|
||||
# Failure path
|
||||
{
|
||||
"from_stage": "extractor",
|
||||
"from_output": "fail",
|
||||
"to_stage": "aggregator",
|
||||
"to_input": "fail",
|
||||
},
|
||||
# Final output
|
||||
{
|
||||
"from_stage": "aggregator",
|
||||
"from_output": "out",
|
||||
"to_stage": "sink",
|
||||
"to_input": "in",
|
||||
},
|
||||
]
|
||||
|
||||
runner = GraphRunner(stages, edges)
|
||||
source_config = {"stage": "source", "channel": "out"}
|
||||
sink_config = {"stage": "sink", "channel": "in"}
|
||||
|
||||
try:
|
||||
async for result in runner.run(
|
||||
input_docs,
|
||||
source_config,
|
||||
sink_config,
|
||||
self.pipeline_options.extraction_queue_size,
|
||||
):
|
||||
yield result
|
||||
except* Exception as eg:
|
||||
_log.error(f"Pipeline failed with exceptions: {eg.exceptions}")
|
||||
raise (eg.exceptions[0] if eg.exceptions else RuntimeError("Unknown error"))
|
||||
finally:
|
||||
await self.page_tracker.cleanup_all()
|
||||
|
||||
async def _finalize_document(self, conv_res: ConversionResult) -> None:
|
||||
"""Finalize a complete document (same as StandardPdfPipeline._assemble_document)"""
|
||||
# This matches the logic from StandardPdfPipeline
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from docling.datamodel.base_models import AssembledUnit
|
||||
|
||||
all_elements = []
|
||||
all_headers = []
|
||||
all_body = []
|
||||
|
||||
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
|
||||
for p in conv_res.pages:
|
||||
if p.assembled is not None:
|
||||
for el in p.assembled.body:
|
||||
all_body.append(el)
|
||||
for el in p.assembled.headers:
|
||||
all_headers.append(el)
|
||||
for el in p.assembled.elements:
|
||||
all_elements.append(el)
|
||||
|
||||
conv_res.assembled = AssembledUnit(
|
||||
elements=all_elements, headers=all_headers, body=all_body
|
||||
)
|
||||
|
||||
conv_res.document = self.reading_order_model(conv_res)
|
||||
|
||||
# Generate page images in the output
|
||||
if self.pipeline_options.generate_page_images:
|
||||
for page in conv_res.pages:
|
||||
if page.image is not None:
|
||||
page_no = page.page_no + 1
|
||||
from docling_core.types.doc import ImageRef
|
||||
|
||||
conv_res.document.pages[page_no].image = ImageRef.from_pil(
|
||||
page.image, dpi=int(72 * self.pipeline_options.images_scale)
|
||||
)
|
||||
|
||||
# Handle picture/table images (same as StandardPdfPipeline)
|
||||
self._generate_element_images(conv_res)
|
||||
|
||||
# Aggregate confidence values
|
||||
self._aggregate_confidence(conv_res)
|
||||
|
||||
# Run enrichment pipeline
|
||||
await self._enrich_document(conv_res)
|
||||
|
||||
# Set final status
|
||||
conv_res.status = self._determine_status(conv_res)
|
||||
|
||||
def _generate_element_images(self, conv_res: ConversionResult) -> None:
|
||||
"""Generate images for elements (same as StandardPdfPipeline)"""
|
||||
import warnings
|
||||
|
||||
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
if (
|
||||
self.pipeline_options.generate_picture_images
|
||||
or self.pipeline_options.generate_table_images
|
||||
):
|
||||
scale = self.pipeline_options.images_scale
|
||||
for element, _level in conv_res.document.iterate_items():
|
||||
if not isinstance(element, DocItem) or len(element.prov) == 0:
|
||||
continue
|
||||
if (
|
||||
isinstance(element, PictureItem)
|
||||
and self.pipeline_options.generate_picture_images
|
||||
) or (
|
||||
isinstance(element, TableItem)
|
||||
and self.pipeline_options.generate_table_images
|
||||
):
|
||||
page_ix = element.prov[0].page_no - 1
|
||||
page = next(
|
||||
(p for p in conv_res.pages if p.page_no == page_ix), None
|
||||
)
|
||||
if (
|
||||
page is not None
|
||||
and page.size is not None
|
||||
and page.image is not None
|
||||
):
|
||||
crop_bbox = (
|
||||
element.prov[0]
|
||||
.bbox.scaled(scale=scale)
|
||||
.to_top_left_origin(
|
||||
page_height=page.size.height * scale
|
||||
)
|
||||
)
|
||||
cropped_im = page.image.crop(crop_bbox.as_tuple())
|
||||
element.image = ImageRef.from_pil(
|
||||
cropped_im, dpi=int(72 * scale)
|
||||
)
|
||||
|
||||
def _aggregate_confidence(self, conv_res: ConversionResult) -> None:
|
||||
"""Aggregate confidence scores (same as StandardPdfPipeline)"""
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
if len(conv_res.pages) > 0:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=RuntimeWarning,
|
||||
message="Mean of empty slice|All-NaN slice encountered",
|
||||
)
|
||||
conv_res.confidence.layout_score = float(
|
||||
np.nanmean(
|
||||
[c.layout_score for c in conv_res.confidence.pages.values()]
|
||||
)
|
||||
)
|
||||
conv_res.confidence.parse_score = float(
|
||||
np.nanquantile(
|
||||
[c.parse_score for c in conv_res.confidence.pages.values()],
|
||||
q=0.1,
|
||||
)
|
||||
)
|
||||
conv_res.confidence.table_score = float(
|
||||
np.nanmean(
|
||||
[c.table_score for c in conv_res.confidence.pages.values()]
|
||||
)
|
||||
)
|
||||
conv_res.confidence.ocr_score = float(
|
||||
np.nanmean(
|
||||
[c.ocr_score for c in conv_res.confidence.pages.values()]
|
||||
)
|
||||
)
|
||||
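To make the aggregation above concrete, an illustrative calculation with hypothetical per-page scores: layout, table and OCR scores are combined with nanmean, while parse_score uses the 10th-percentile nanquantile so a few badly parsed pages drag the document-level score down.

import numpy as np

# Hypothetical per-page scores for a three-page document.
layout_pages = [0.90, 0.80, 0.95]
parse_pages = [0.99, 0.98, 0.40]   # one poorly parsed page

layout_score = float(np.nanmean(layout_pages))            # ~0.883
parse_score = float(np.nanquantile(parse_pages, q=0.1))   # ~0.516, dominated by the worst page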
|
||||
async def _enrich_document(self, conv_res: ConversionResult) -> None:
|
||||
"""Run enrichment models on document"""
|
||||
# Run enrichment models (same as base pipeline but async)
|
||||
from docling.utils.utils import chunkify
|
||||
|
||||
enrichment_models = [
|
||||
self.code_formula_model,
|
||||
self.picture_classifier,
|
||||
self.picture_description_model,
|
||||
]
|
||||
|
||||
for model in enrichment_models:
|
||||
if model is None or not getattr(model, "enabled", True):
|
||||
continue
|
||||
|
||||
# Prepare elements
|
||||
elements_to_process = []
|
||||
for doc_element, _level in conv_res.document.iterate_items():
|
||||
prepared = model.prepare_element(conv_res=conv_res, element=doc_element)
|
||||
if prepared is not None:
|
||||
elements_to_process.append(prepared)
|
||||
|
||||
# Process in batches
|
||||
for element_batch in chunkify(
|
||||
elements_to_process, model.elements_batch_size
|
||||
):
|
||||
# Run model in shared thread pool to avoid blocking
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool,
|
||||
lambda: list(model(conv_res.document, element_batch)),
|
||||
)
|
||||
|
||||
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
|
||||
"""Determine conversion status"""
|
||||
# Simple implementation - could be enhanced
|
||||
if conv_res.pages and conv_res.document:
|
||||
return ConversionStatus.SUCCESS
|
||||
else:
|
||||
return ConversionStatus.FAILURE
|
@@ -1,180 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterable, Dict, List, Literal, Optional
|
||||
|
||||
# Sentinel to signal stream completion
|
||||
STOP_SENTINEL = object()
|
||||
|
||||
# Global thread pool for pipeline operations - shared across all stages
|
||||
_PIPELINE_THREAD_POOL: Optional[ThreadPoolExecutor] = None
|
||||
_THREAD_POOL_REFS = weakref.WeakSet()
|
||||
|
||||
|
||||
def get_pipeline_thread_pool(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
|
||||
"""Get or create the shared pipeline thread pool."""
|
||||
global _PIPELINE_THREAD_POOL
|
||||
if _PIPELINE_THREAD_POOL is None or _PIPELINE_THREAD_POOL._shutdown:
|
||||
_PIPELINE_THREAD_POOL = ThreadPoolExecutor(
|
||||
max_workers=max_workers, thread_name_prefix="docling_pipeline"
|
||||
)
|
||||
_THREAD_POOL_REFS.add(_PIPELINE_THREAD_POOL)
|
||||
return _PIPELINE_THREAD_POOL
|
||||
|
||||
|
||||
def shutdown_pipeline_thread_pool(wait: bool = True) -> None:
|
||||
"""Shutdown the shared thread pool."""
|
||||
global _PIPELINE_THREAD_POOL
|
||||
if _PIPELINE_THREAD_POOL is not None:
|
||||
_PIPELINE_THREAD_POOL.shutdown(wait=wait)
|
||||
_PIPELINE_THREAD_POOL = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StreamItem:
|
||||
"""
|
||||
A wrapper for data flowing through the pipeline, maintaining a link
|
||||
to the original conversion result context.
|
||||
"""
|
||||
|
||||
payload: Any
|
||||
conv_res_id: int
|
||||
conv_res: Any # Opaque reference to ConversionResult
|
||||
|
||||
|
||||
class PipelineStage(ABC):
|
||||
"""A single, encapsulated step in a processing pipeline graph."""
|
||||
|
||||
def __init__(self, name: str, max_workers: Optional[int] = None):
|
||||
self.name = name
|
||||
self.input_queues: Dict[str, asyncio.Queue] = {}
|
||||
self.output_queues: Dict[str, List[asyncio.Queue]] = {}
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread_pool = get_pipeline_thread_pool(max_workers)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
The core execution logic for the stage. This method is responsible for
|
||||
consuming from input queues, processing data, and putting results into
|
||||
output queues.
|
||||
"""
|
||||
|
||||
async def _send_to_outputs(self, channel: str, items: List[StreamItem] | List[Any]):
|
||||
"""Helper to send processed items to all connected output queues."""
|
||||
if channel in self.output_queues:
|
||||
for queue in self.output_queues[channel]:
|
||||
for item in items:
|
||||
await queue.put(item)
|
||||
|
||||
async def _signal_downstream_completion(self):
|
||||
"""Signal that this stage is done processing to all output channels."""
|
||||
for channel_queues in self.output_queues.values():
|
||||
for queue in channel_queues:
|
||||
await queue.put(STOP_SENTINEL)
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
return self._loop
|
||||
|
||||
@property
|
||||
def thread_pool(self) -> ThreadPoolExecutor:
|
||||
"""Get the shared thread pool for this stage."""
|
||||
return self._thread_pool
|
||||
|
||||
|
||||
class GraphRunner:
|
||||
"""Connects stages and runs the pipeline graph."""
|
||||
|
||||
def __init__(self, stages: List[PipelineStage], edges: List[Dict[str, str]]):
|
||||
self._stages = {s.name: s for s in stages}
|
||||
self._edges = edges
|
||||
|
||||
def _wire_graph(self, queue_max_size: int):
|
||||
"""Create queues for edges and connect them to stage inputs and outputs."""
|
||||
for edge in self._edges:
|
||||
from_stage, from_output = edge["from_stage"], edge["from_output"]
|
||||
to_stage, to_input = edge["to_stage"], edge["to_input"]
|
||||
|
||||
queue = asyncio.Queue(maxsize=queue_max_size)
|
||||
|
||||
# Connect to source stage's output
|
||||
self._stages[from_stage].output_queues.setdefault(from_output, []).append(
|
||||
queue
|
||||
)
|
||||
|
||||
# Connect to destination stage's input
|
||||
self._stages[to_stage].input_queues[to_input] = queue
|
||||
|
||||
async def _run_source(
|
||||
self,
|
||||
source_stream: AsyncIterable[Any],
|
||||
source_stage: str,
|
||||
source_channel: str,
|
||||
):
|
||||
"""Feed the graph from an external async iterable."""
|
||||
output_queues = self._stages[source_stage].output_queues.get(source_channel, [])
|
||||
async for item in source_stream:
|
||||
for queue in output_queues:
|
||||
await queue.put(item)
|
||||
# Signal completion to all downstream queues
|
||||
for queue in output_queues:
|
||||
await queue.put(STOP_SENTINEL)
|
||||
|
||||
async def _run_sink(self, sink_stage: str, sink_channel: str) -> AsyncIterable[Any]:
|
||||
"""Yield results from the graph's final output queue."""
|
||||
queue = self._stages[sink_stage].input_queues[sink_channel]
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is STOP_SENTINEL:
|
||||
break
|
||||
yield item
|
||||
await queue.put(STOP_SENTINEL) # Allow other sinks to terminate
|
||||
|
||||
async def run(
|
||||
self,
|
||||
source_stream: AsyncIterable,
|
||||
source_config: Dict[str, str],
|
||||
sink_config: Dict[str, str],
|
||||
queue_max_size: int = 32,
|
||||
) -> AsyncIterable:
|
||||
"""
|
||||
Executes the entire pipeline graph.
|
||||
|
||||
Args:
|
||||
source_stream: The initial async iterable to feed the graph.
|
||||
source_config: Dictionary with "stage" and "channel" for the entry point.
|
||||
sink_config: Dictionary with "stage" and "channel" for the exit point.
|
||||
queue_max_size: The max size for the internal asyncio.Queues.
|
||||
"""
|
||||
self._wire_graph(queue_max_size)
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
# Create a task for the source feeder
|
||||
tg.create_task(
|
||||
self._run_source(
|
||||
source_stream, source_config["stage"], source_config["channel"]
|
||||
)
|
||||
)
|
||||
|
||||
# Create tasks for all pipeline stages
|
||||
for stage in self._stages.values():
|
||||
tg.create_task(stage.run())
|
||||
|
||||
# Yield results from the sink
|
||||
async for result in self._run_sink(
|
||||
sink_config["stage"], sink_config["channel"]
|
||||
):
|
||||
yield result
|
||||
finally:
|
||||
# Ensure thread pool cleanup on pipeline completion
|
||||
# Note: We don't shutdown here as other pipelines might be using it
|
||||
pass
|
@@ -1,101 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.pipeline.async_base_pipeline import DocumentTracker
|
||||
from docling.pipeline.graph import get_pipeline_thread_pool
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncPageTracker:
|
||||
"""Manages page backend lifecycle across documents"""
|
||||
|
||||
_doc_trackers: Dict[str, DocumentTracker] = field(default_factory=dict)
|
||||
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
keep_images: bool = False
|
||||
keep_backend: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize shared thread pool reference after dataclass creation"""
|
||||
self._thread_pool = get_pipeline_thread_pool()
|
||||
|
||||
async def register_document(
|
||||
self, conv_res: ConversionResult, total_pages: int
|
||||
) -> str:
|
||||
"""Register a new document for tracking"""
|
||||
async with self._lock:
|
||||
# Use UUID for better collision resistance than str(id())
|
||||
doc_id = str(uuid.uuid4())
|
||||
self._doc_trackers[doc_id] = DocumentTracker(
|
||||
doc_id=doc_id, total_pages=total_pages, conv_result=conv_res
|
||||
)
|
||||
# Store the doc_id in the conv_res for later lookup
|
||||
conv_res._async_doc_id = doc_id
|
||||
return doc_id
|
||||
|
||||
async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None:
|
||||
"""Track when a page backend is loaded"""
|
||||
async with self._lock:
|
||||
doc_id = getattr(conv_res, "_async_doc_id", None)
|
||||
if doc_id and doc_id in self._doc_trackers and page._backend is not None:
|
||||
self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend
|
||||
|
||||
async def track_page_completion(
|
||||
self, page: Page, conv_res: ConversionResult
|
||||
) -> bool:
|
||||
"""Track page completion and cleanup when all pages done"""
|
||||
async with self._lock:
|
||||
doc_id = getattr(conv_res, "_async_doc_id", None)
|
||||
if not doc_id or doc_id not in self._doc_trackers:
|
||||
_log.warning(f"Document {doc_id} not registered for tracking")
|
||||
return False
|
||||
|
||||
tracker = self._doc_trackers[doc_id]
|
||||
tracker.processed_pages += 1
|
||||
|
||||
# Clear this page's image cache if needed
|
||||
if not self.keep_images:
|
||||
page._image_cache = {}
|
||||
|
||||
# If all pages from this document are processed, cleanup
|
||||
if tracker.processed_pages == tracker.total_pages:
|
||||
await self._cleanup_document_resources(tracker)
|
||||
del self._doc_trackers[doc_id]
|
||||
# Clean up the doc_id from conv_res
|
||||
if hasattr(conv_res, "_async_doc_id"):
|
||||
delattr(conv_res, "_async_doc_id")
|
||||
return True # Document is complete
|
||||
|
||||
return False # Document is not yet complete
|
||||
|
||||
async def _cleanup_document_resources(self, tracker: DocumentTracker) -> None:
|
||||
"""Cleanup all resources for a completed document"""
|
||||
if not self.keep_backend:
|
||||
# Unload all page backends for this document
|
||||
for page_no, backend in tracker.page_backends.items():
|
||||
if backend is not None:
|
||||
try:
|
||||
# Run unload in shared thread pool to avoid blocking
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, backend.unload
|
||||
)
|
||||
except Exception as e:
|
||||
_log.warning(
|
||||
f"Failed to unload backend for page {page_no}: {e}"
|
||||
)
|
||||
|
||||
tracker.page_backends.clear()
|
||||
_log.debug(f"Cleaned up resources for document {tracker.doc_id}")
|
||||
|
||||
async def cleanup_all(self) -> None:
|
||||
"""Cleanup all tracked documents - for shutdown"""
|
||||
async with self._lock:
|
||||
for tracker in self._doc_trackers.values():
|
||||
await self._cleanup_document_resources(tracker)
|
||||
self._doc_trackers.clear()
|
@@ -1,300 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
|
||||
|
||||
from docling.datamodel.document import ConversionResult, InputDocument, Page
|
||||
from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SourceStage(PipelineStage):
|
||||
"""A placeholder stage to represent the entry point of the graph."""
|
||||
|
||||
async def run(self) -> None:
|
||||
# This stage is driven by the GraphRunner's _run_source method
|
||||
# and does not have its own execution loop.
|
||||
pass
|
||||
|
||||
|
||||
class SinkStage(PipelineStage):
|
||||
"""A placeholder stage to represent the exit point of the graph."""
|
||||
|
||||
async def run(self) -> None:
|
||||
# This stage is read by the GraphRunner's _run_sink method
|
||||
# and does not have its own execution loop.
|
||||
pass
|
||||
|
||||
|
||||
class ExtractionStage(PipelineStage):
|
||||
"""Extracts pages from documents and tracks them."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
page_tracker: Any,
|
||||
max_concurrent_extractions: int,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.page_tracker = page_tracker
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent_extractions)
|
||||
self.input_channel = "in"
|
||||
self.output_channel = "out"
|
||||
self.failure_channel = "fail"
|
||||
|
||||
async def _extract_page(
|
||||
self, page_no: int, conv_res: ConversionResult
|
||||
) -> StreamItem | None:
|
||||
"""Coroutine to extract a single page."""
|
||||
try:
|
||||
async with self.semaphore:
|
||||
page = Page(page_no=page_no)
|
||||
conv_res.pages.append(page)
|
||||
|
||||
page._backend = await self.loop.run_in_executor(
|
||||
self.thread_pool, conv_res.input._backend.load_page, page_no
|
||||
)
|
||||
|
||||
if page._backend and page._backend.is_valid():
|
||||
page.size = page._backend.get_size()
|
||||
await self.page_tracker.track_page_loaded(page, conv_res)
|
||||
return StreamItem(
|
||||
payload=page, conv_res_id=id(conv_res), conv_res=conv_res
|
||||
)
|
||||
else:
|
||||
_log.warning(
|
||||
f"Failed to load or validate page {page_no} from document {conv_res.input.file.name}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
_log.error(
|
||||
f"Error extracting page {page_no} from document {conv_res.input.file.name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Don't propagate individual page failures - document-level error handling will catch this
|
||||
return None
|
||||
|
||||
async def _process_document(self, in_doc: InputDocument):
|
||||
"""Processes a single document, extracting all its pages."""
|
||||
conv_res = ConversionResult(input=in_doc)
|
||||
|
||||
try:
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
|
||||
if not isinstance(in_doc._backend, PdfDocumentBackend):
|
||||
raise TypeError("Backend is not a valid PdfDocumentBackend")
|
||||
|
||||
total_pages = in_doc.page_count
|
||||
await self.page_tracker.register_document(conv_res, total_pages)
|
||||
|
||||
start_page, end_page = conv_res.input.limits.page_range
|
||||
page_indices_to_extract = [
|
||||
i for i in range(total_pages) if (start_page - 1) <= i <= (end_page - 1)
|
||||
]
|
||||
|
||||
tasks = [
|
||||
self.loop.create_task(self._extract_page(i, conv_res))
|
||||
for i in page_indices_to_extract
|
||||
]
|
||||
pages_extracted = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Filter out None results and exceptions, log any exceptions found
|
||||
valid_pages = []
|
||||
for i, result in enumerate(pages_extracted):
|
||||
if isinstance(result, Exception):
|
||||
_log.error(
|
||||
f"Page extraction failed for page {page_indices_to_extract[i]} "
|
||||
f"in document {in_doc.file.name}: {result}"
|
||||
)
|
||||
elif result is not None:
|
||||
valid_pages.append(result)
|
||||
|
||||
await self._send_to_outputs(self.output_channel, valid_pages)
|
||||
|
||||
# If no pages were successfully extracted, mark as failure
|
||||
if not valid_pages:
|
||||
_log.error(
|
||||
f"No pages could be extracted from document {in_doc.file.name}"
|
||||
)
|
||||
conv_res.status = "FAILURE"
|
||||
await self._send_to_outputs(self.failure_channel, [conv_res])
|
||||
|
||||
except Exception as e:
|
||||
_log.error(
|
||||
f"Document-level extraction failed for {in_doc.file.name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
conv_res.status = "FAILURE"
|
||||
await self._send_to_outputs(self.failure_channel, [conv_res])
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Main loop to consume documents and launch extraction tasks."""
|
||||
q_in = self.input_queues[self.input_channel]
|
||||
while True:
|
||||
doc = await q_in.get()
|
||||
if doc is STOP_SENTINEL:
|
||||
await self._signal_downstream_completion()
|
||||
break
|
||||
await self._process_document(doc)
|
||||
|
||||
|
||||
class PageProcessorStage(PipelineStage):
|
||||
"""Applies a synchronous, 1-to-1 processing function to each page."""
|
||||
|
||||
def __init__(self, name: str, model: Any):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
self.input_channel = "in"
|
||||
self.output_channel = "out"
|
||||
|
||||
async def run(self) -> None:
|
||||
q_in = self.input_queues[self.input_channel]
|
||||
while True:
|
||||
item = await q_in.get()
|
||||
if item is STOP_SENTINEL:
|
||||
await self._signal_downstream_completion()
|
||||
break
|
||||
|
||||
# The model call is sync, run in thread to avoid blocking event loop
|
||||
processed_page = await self.loop.run_in_executor(
|
||||
self.thread_pool,
|
||||
lambda: next(iter(self.model(item.conv_res, [item.payload]))),
|
||||
)
|
||||
item.payload = processed_page
|
||||
await self._send_to_outputs(self.output_channel, [item])
|
||||
|
||||
|
||||
class BatchProcessorStage(PipelineStage):
|
||||
"""Batches items and applies a synchronous model to the batch."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model: Any,
|
||||
batch_size: int,
|
||||
batch_timeout: float,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.batch_timeout = batch_timeout
|
||||
self.input_channel = "in"
|
||||
self.output_channel = "out"
|
||||
|
||||
async def _collect_batch(self, q_in: asyncio.Queue) -> List[StreamItem] | None:
|
||||
"""Collects a batch of items from the input queue with a timeout."""
|
||||
try:
|
||||
# Wait for the first item without a timeout
|
||||
first_item = await q_in.get()
|
||||
if first_item is STOP_SENTINEL:
|
||||
return None # End of stream
|
||||
except asyncio.CancelledError:
|
||||
return None
|
||||
|
||||
batch = [first_item]
|
||||
start_time = self.loop.time()
|
||||
|
||||
while len(batch) < self.batch_size:
|
||||
timeout = self.batch_timeout - (self.loop.time() - start_time)
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
item = await asyncio.wait_for(q_in.get(), timeout)
|
||||
if item is STOP_SENTINEL:
|
||||
# Put sentinel back for other potential consumers or the main loop
|
||||
await q_in.put(STOP_SENTINEL)
|
||||
break
|
||||
batch.append(item)
|
||||
except asyncio.TimeoutError:
|
||||
break # Batching timeout reached
|
||||
return batch
|
||||
|
||||
async def run(self) -> None:
|
||||
q_in = self.input_queues[self.input_channel]
|
||||
while True:
|
||||
batch = await self._collect_batch(q_in)
|
||||
|
||||
if not batch: # This can be None or an empty list.
|
||||
await self._signal_downstream_completion()
|
||||
break
|
||||
|
||||
# Group pages by their original ConversionResult
|
||||
grouped_by_doc = defaultdict(list)
|
||||
for item in batch:
|
||||
grouped_by_doc[item.conv_res_id].append(item)
|
||||
|
||||
processed_items = []
|
||||
for conv_res_id, items in grouped_by_doc.items():
|
||||
conv_res = items[0].conv_res
|
||||
pages = [item.payload for item in items]
|
||||
|
||||
# The model call is sync, run in thread
|
||||
processed_pages = await self.loop.run_in_executor(
|
||||
self.thread_pool, lambda: list(self.model(conv_res, pages))
|
||||
)
|
||||
|
||||
# Re-wrap the processed pages into StreamItems
|
||||
for i, page in enumerate(processed_pages):
|
||||
processed_items.append(
|
||||
StreamItem(
|
||||
payload=page,
|
||||
conv_res_id=items[i].conv_res_id,
|
||||
conv_res=items[i].conv_res,
|
||||
)
|
||||
)
|
||||
|
||||
await self._send_to_outputs(self.output_channel, processed_items)
|
||||
|
||||
|
||||
class AggregationStage(PipelineStage):
|
||||
"""Aggregates processed pages back into completed documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
page_tracker: Any,
|
||||
finalizer_func: Callable[[ConversionResult], Coroutine],
|
||||
):
|
||||
super().__init__(name)
|
||||
self.page_tracker = page_tracker
|
||||
self.finalizer_func = finalizer_func
|
||||
self.success_channel = "in"
|
||||
self.failure_channel = "fail"
|
||||
self.output_channel = "out"
|
||||
|
||||
async def run(self) -> None:
|
||||
success_q = self.input_queues[self.success_channel]
|
||||
failure_q = self.input_queues.get(self.failure_channel)
|
||||
|
||||
async def handle_successes():
|
||||
while True:
|
||||
item = await success_q.get()
|
||||
if item is STOP_SENTINEL:
|
||||
break
|
||||
is_doc_complete = await self.page_tracker.track_page_completion(
|
||||
item.payload, item.conv_res
|
||||
)
|
||||
if is_doc_complete:
|
||||
await self.finalizer_func(item.conv_res)
|
||||
await self._send_to_outputs(self.output_channel, [item.conv_res])
|
||||
|
||||
async def handle_failures():
|
||||
if failure_q is None:
|
||||
return # No failure channel, nothing to do
|
||||
while True:
|
||||
failed_res = await failure_q.get()
|
||||
if failed_res is STOP_SENTINEL:
|
||||
break
|
||||
await self._send_to_outputs(self.output_channel, [failed_res])
|
||||
|
||||
# Create tasks only for channels that exist
|
||||
tasks = [handle_successes()]
|
||||
if failure_q is not None:
|
||||
tasks.append(handle_failures())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
await self._signal_downstream_completion()
|
835  docling/pipeline/threaded_standard_pdf_pipeline.py  Normal file
@@ -0,0 +1,835 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
|
||||
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import AssembledUnit, ConversionStatus, Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||
from docling.models.document_picture_classifier import (
|
||||
DocumentPictureClassifier,
|
||||
DocumentPictureClassifierOptions,
|
||||
)
|
||||
from docling.models.factories import get_ocr_factory, get_picture_description_factory
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
|
||||
from docling.models.page_preprocessing_model import (
|
||||
PagePreprocessingModel,
|
||||
PagePreprocessingOptions,
|
||||
)
|
||||
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
||||
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.pipeline.base_pipeline import BasePipeline
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
from docling.utils.utils import chunkify
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreadedItem:
|
||||
"""Item flowing through the threaded pipeline with document context"""
|
||||
|
||||
payload: Page
|
||||
conv_res_id: int
|
||||
conv_res: ConversionResult
|
||||
page_no: int = -1
|
||||
error: Optional[Exception] = None
|
||||
is_failed: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""Ensure proper initialization of page number"""
|
||||
if self.page_no == -1 and isinstance(self.payload, Page):
|
||||
self.page_no = self.payload.page_no
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingResult:
|
||||
"""Result of processing with error tracking for partial results"""
|
||||
|
||||
pages: List[Page] = field(default_factory=list)
|
||||
failed_pages: List[Tuple[int, Exception]] = field(default_factory=list)
|
||||
total_expected: int = 0
|
||||
|
||||
@property
|
||||
def success_count(self) -> int:
|
||||
return len(self.pages)
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return len(self.failed_pages)
|
||||
|
||||
@property
|
||||
def is_partial_success(self) -> bool:
|
||||
return self.success_count > 0 and self.failure_count > 0
|
||||
|
||||
@property
|
||||
def is_complete_failure(self) -> bool:
|
||||
return self.success_count == 0 and self.failure_count > 0
|
||||
|
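An illustrative sketch, not part of the new file, of how ProcessingResult's counters could map onto a conversion status (the helper name is hypothetical):

from docling.datamodel.base_models import ConversionStatus

def status_from_result(result: ProcessingResult) -> ConversionStatus:
    # Every expected page came through the stages successfully.
    if result.failure_count == 0 and result.success_count == result.total_expected:
        return ConversionStatus.SUCCESS
    # Some pages survived, some failed: keep the partial document.
    if result.is_partial_success:
        return ConversionStatus.PARTIAL_SUCCESS
    # Nothing usable came back.
    return ConversionStatus.FAILURE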
||||
|
||||
@dataclass
class ThreadedQueue:
    """Thread-safe queue with backpressure control and memory management"""

    max_size: int = 100
    items: deque = field(default_factory=deque)
    lock: threading.Lock = field(default_factory=threading.Lock)
    not_full: threading.Condition = field(init=False)
    not_empty: threading.Condition = field(init=False)
    closed: bool = False

    def __post_init__(self):
        self.not_full = threading.Condition(self.lock)
        self.not_empty = threading.Condition(self.lock)

    def put(self, item: ThreadedItem, timeout: Optional[float] = None) -> bool:
        """Put item with backpressure control"""
        with self.not_full:
            if self.closed:
                return False

            start_time = time.time()
            while len(self.items) >= self.max_size and not self.closed:
                if timeout is not None:
                    remaining = timeout - (time.time() - start_time)
                    if remaining <= 0:
                        return False
                    self.not_full.wait(remaining)
                else:
                    self.not_full.wait()

            if self.closed:
                return False

            self.items.append(item)
            self.not_empty.notify()
            return True

    def get_batch(
        self, batch_size: int, timeout: Optional[float] = None
    ) -> List[ThreadedItem]:
        """Get a batch of items"""
        with self.not_empty:
            start_time = time.time()

            # Wait for at least one item
            while len(self.items) == 0 and not self.closed:
                if timeout is not None:
                    remaining = timeout - (time.time() - start_time)
                    if remaining <= 0:
                        return []
                    self.not_empty.wait(remaining)
                else:
                    self.not_empty.wait()

            # Collect batch
            batch: List[ThreadedItem] = []
            while len(batch) < batch_size and len(self.items) > 0:
                batch.append(self.items.popleft())

            if batch:
                self.not_full.notify_all()

            return batch

    def close(self):
        """Close the queue and wake up waiting threads"""
        with self.lock:
            self.closed = True
            self.not_empty.notify_all()
            self.not_full.notify_all()

    def is_empty(self) -> bool:
        with self.lock:
            return len(self.items) == 0

    def size(self) -> int:
        with self.lock:
            return len(self.items)

    def cleanup(self):
        """Clean up resources and clear items"""
        with self.lock:
            self.items.clear()
            self.closed = True

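
# Minimal usage sketch for ThreadedQueue (the function is hypothetical and only
# illustrates the contract): put() blocks until space frees up or the timeout
# expires, get_batch() returns whatever is queued once at least one item is
# available, and close() unblocks all waiting producers and consumers.
def _queue_roundtrip_example(
    conv_res: ConversionResult, pages: List[Page]
) -> Tuple[int, int]:
    q = ThreadedQueue(max_size=2)
    accepted = 0
    for page in pages:
        # With max_size=2 and no consumer, the third put() waits up to 0.1s and
        # then returns False instead of blocking forever.
        if q.put(
            ThreadedItem(payload=page, conv_res_id=id(conv_res), conv_res=conv_res),
            timeout=0.1,
        ):
            accepted += 1
    drained = q.get_batch(batch_size=accepted, timeout=0.1)
    q.close()
    return accepted, len(drained)
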
class ThreadedPipelineStage:
    """A pipeline stage that processes items using a dedicated worker thread"""

    def __init__(
        self,
        name: str,
        model: Any,
        batch_size: int,
        batch_timeout: float,
        queue_max_size: int,
    ):
        self.name = name
        self.model = model
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout
        self.input_queue = ThreadedQueue(max_size=queue_max_size)
        self.output_queues: List[ThreadedQueue] = []
        self.running = False
        self.thread: Optional[threading.Thread] = None

    def add_output_queue(self, queue: ThreadedQueue):
        """Connect this stage to an output queue"""
        self.output_queues.append(queue)

    def start(self):
        """Start the stage processing thread"""
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self._run, name=f"Stage-{self.name}")
            self.thread.daemon = False  # Ensure proper shutdown
            self.thread.start()

    def stop(self):
        """Stop the stage processing"""
        self.running = False
        self.input_queue.close()
        if self.thread:
            self.thread.join(timeout=30.0)  # Reasonable timeout for shutdown
            if self.thread.is_alive():
                _log.warning(f"Stage {self.name} thread did not shut down gracefully")

    def _run(self):
        """Main processing loop for the stage"""
        try:
            while self.running:
                batch = self.input_queue.get_batch(
                    self.batch_size, timeout=self.batch_timeout
                )

                if not batch and self.input_queue.closed:
                    break

                if batch:
                    try:
                        processed_items = self._process_batch(batch)
                        self._send_to_outputs(processed_items)
                    except Exception as e:
                        _log.error(f"Error in stage {self.name}: {e}", exc_info=True)
                        # Send failed items downstream for partial processing
                        failed_items = []
                        for item in batch:
                            item.is_failed = True
                            item.error = e
                            failed_items.append(item)
                        self._send_to_outputs(failed_items)

        except Exception as e:
            _log.error(f"Fatal error in stage {self.name}: {e}", exc_info=True)
        finally:
            # Close output queues when done
            for queue in self.output_queues:
                queue.close()

    def _process_batch(self, batch: List[ThreadedItem]) -> List[ThreadedItem]:
        """Process a batch through the model with error handling"""
        # Group by document to maintain document integrity
        grouped_by_doc = defaultdict(list)
        for item in batch:
            grouped_by_doc[item.conv_res_id].append(item)

        processed_items = []
        for conv_res_id, items in grouped_by_doc.items():
            try:
                # Filter out already failed items
                valid_items = [item for item in items if not item.is_failed]
                failed_items = [item for item in items if item.is_failed]

                if valid_items:
                    conv_res = valid_items[0].conv_res
                    pages = [item.payload for item in valid_items]

                    # Process through model
                    processed_pages = list(self.model(conv_res, pages))

                    # Re-wrap processed pages
                    for i, page in enumerate(processed_pages):
                        processed_items.append(
                            ThreadedItem(
                                payload=page,
                                conv_res_id=valid_items[i].conv_res_id,
                                conv_res=valid_items[i].conv_res,
                                page_no=valid_items[i].page_no,
                            )
                        )

                # Pass through failed items for downstream handling
                processed_items.extend(failed_items)

            except Exception as e:
                _log.error(f"Model {self.name} failed for document {conv_res_id}: {e}")
                # Mark all items as failed but continue processing
                for item in items:
                    item.is_failed = True
                    item.error = e
                    processed_items.append(item)

        return processed_items

    def _send_to_outputs(self, items: List[ThreadedItem]):
        """Send processed items to output queues"""
        for item in items:
            for queue in self.output_queues:
                # Use timeout to prevent blocking indefinitely
                if not queue.put(item, timeout=5.0):
                    _log.warning(
                        f"Failed to send item from {self.name} due to backpressure"
                    )

    def cleanup(self):
        """Clean up stage resources"""
        if self.input_queue:
            self.input_queue.cleanup()
        for queue in self.output_queues:
            queue.cleanup()

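
# Sketch of how stages chain together (hypothetical example, not used by the
# pipeline): a stage's "model" is any callable taking (conv_res, pages) and
# returning processed pages, and stages connect by sharing ThreadedQueues.
def _two_stage_example(
    conv_res: ConversionResult, pages: List[Page]
) -> List[ThreadedItem]:
    def identity_model(_conv_res: ConversionResult, batch: List[Page]) -> List[Page]:
        return list(batch)  # placeholder model: passes pages through unchanged

    first = ThreadedPipelineStage("first", identity_model, 2, 0.5, 10)
    second = ThreadedPipelineStage("second", identity_model, 2, 0.5, 10)
    sink = ThreadedQueue(max_size=10)

    first.add_output_queue(second.input_queue)
    second.add_output_queue(sink)
    first.start()
    second.start()

    for page in pages:
        first.input_queue.put(
            ThreadedItem(payload=page, conv_res_id=id(conv_res), conv_res=conv_res)
        )
    first.input_queue.close()  # lets the first stage drain and shut down

    # Collect in a loop, as the pipeline's own collector does: stop once all
    # items arrived or a timeout elapses with nothing new.
    results: List[ThreadedItem] = []
    while len(results) < len(pages):
        batch = sink.get_batch(batch_size=len(pages) - len(results), timeout=5.0)
        if not batch:
            break
        results.extend(batch)

    first.stop()
    second.stop()
    return results
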
class ThreadedStandardPdfPipeline(BasePipeline):
    """
    A threaded pipeline implementation that processes pages through
    dedicated stage threads with batching and backpressure control.
    """

    def __init__(self, pipeline_options: ThreadedPdfPipelineOptions):
        super().__init__(pipeline_options)
        self.pipeline_options: ThreadedPdfPipelineOptions = pipeline_options

        # Initialize attributes with proper type annotations
        self.keep_backend: bool = False
        self.keep_images: bool = False

        # Model attributes - will be initialized in _initialize_models
        self.preprocessing_model: PagePreprocessingModel
        self.ocr_model: Any  # OCR models have different base types from factory
        self.layout_model: LayoutModel
        self.table_model: TableStructureModel
        self.assemble_model: PageAssembleModel
        self.reading_order_model: ReadingOrderModel

        self._initialize_models()
        self._setup_pipeline()

        # Use weak references for memory management
        self._document_tracker: weakref.WeakValueDictionary[int, ConversionResult] = (
            weakref.WeakValueDictionary()
        )
        self._document_lock = threading.Lock()

    def _get_artifacts_path(self) -> Optional[Path]:
        """Get artifacts path from options or settings"""
        artifacts_path = None
        if self.pipeline_options.artifacts_path is not None:
            artifacts_path = Path(self.pipeline_options.artifacts_path).expanduser()
        elif settings.artifacts_path is not None:
            artifacts_path = Path(settings.artifacts_path).expanduser()

        if artifacts_path is not None and not artifacts_path.is_dir():
            raise RuntimeError(
                f"The value of {artifacts_path=} is not valid. "
                "When defined, it must point to a folder containing all models required by the pipeline."
            )
        return artifacts_path

    def _get_ocr_model(self, artifacts_path: Optional[Path] = None):
        """Get OCR model instance"""
        factory = get_ocr_factory(
            allow_external_plugins=self.pipeline_options.allow_external_plugins
        )
        return factory.create_instance(
            options=self.pipeline_options.ocr_options,
            enabled=self.pipeline_options.do_ocr,
            artifacts_path=artifacts_path,
            accelerator_options=self.pipeline_options.accelerator_options,
        )

    def _get_picture_description_model(self, artifacts_path: Optional[Path] = None):
        """Get picture description model instance"""
        factory = get_picture_description_factory(
            allow_external_plugins=self.pipeline_options.allow_external_plugins
        )
        return factory.create_instance(
            options=self.pipeline_options.picture_description_options,
            enabled=self.pipeline_options.do_picture_description,
            enable_remote_services=self.pipeline_options.enable_remote_services,
            artifacts_path=artifacts_path,
            accelerator_options=self.pipeline_options.accelerator_options,
        )

    def _initialize_models(self):
        """Initialize all pipeline models"""
        artifacts_path = self._get_artifacts_path()

        # Check if we need to keep images for processing
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            self.keep_images = (
                self.pipeline_options.generate_page_images
                or self.pipeline_options.generate_picture_images
                or self.pipeline_options.generate_table_images
            )

        self.preprocessing_model = PagePreprocessingModel(
            options=PagePreprocessingOptions(
                images_scale=self.pipeline_options.images_scale,
            )
        )

        self.ocr_model = self._get_ocr_model(artifacts_path)

        self.layout_model = LayoutModel(
            artifacts_path=artifacts_path,
            accelerator_options=self.pipeline_options.accelerator_options,
            options=self.pipeline_options.layout_options,
        )

        self.table_model = TableStructureModel(
            enabled=self.pipeline_options.do_table_structure,
            artifacts_path=artifacts_path,
            options=self.pipeline_options.table_structure_options,
            accelerator_options=self.pipeline_options.accelerator_options,
        )

        self.assemble_model = PageAssembleModel(options=PageAssembleOptions())

        # Reading order and enrichment models
        self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())

        # Initialize enrichment models and add only enabled ones to enrichment_pipe
        self.enrichment_pipe = []

        # Code Formula Enrichment Model
        code_formula_model = CodeFormulaModel(
            enabled=self.pipeline_options.do_code_enrichment
            or self.pipeline_options.do_formula_enrichment,
            artifacts_path=artifacts_path,
            options=CodeFormulaModelOptions(
                do_code_enrichment=self.pipeline_options.do_code_enrichment,
                do_formula_enrichment=self.pipeline_options.do_formula_enrichment,
            ),
            accelerator_options=self.pipeline_options.accelerator_options,
        )
        if code_formula_model.enabled:
            self.enrichment_pipe.append(code_formula_model)

        # Document Picture Classifier
        picture_classifier = DocumentPictureClassifier(
            enabled=self.pipeline_options.do_picture_classification,
            artifacts_path=artifacts_path,
            options=DocumentPictureClassifierOptions(),
            accelerator_options=self.pipeline_options.accelerator_options,
        )
        if picture_classifier.enabled:
            self.enrichment_pipe.append(picture_classifier)

        # Picture description model
        picture_description_model = self._get_picture_description_model(artifacts_path)
        if picture_description_model is not None and picture_description_model.enabled:
            self.enrichment_pipe.append(picture_description_model)

        # Determine if we need to keep backend for enrichment
        if (
            self.pipeline_options.do_formula_enrichment
            or self.pipeline_options.do_code_enrichment
            or self.pipeline_options.do_picture_classification
            or self.pipeline_options.do_picture_description
        ):
            self.keep_backend = True

    def _setup_pipeline(self):
        """Set up the pipeline stages and connections with proper typing"""
        # Use pipeline options directly - they have proper defaults
        opts = self.pipeline_options

        # Create pipeline stages
        self.preprocess_stage = ThreadedPipelineStage(
            "preprocess",
            self.preprocessing_model,
            1,
            opts.batch_timeout_seconds,
            opts.queue_max_size,
        )
        self.ocr_stage = ThreadedPipelineStage(
            "ocr",
            self.ocr_model,
            opts.ocr_batch_size,
            opts.batch_timeout_seconds,
            opts.queue_max_size,
        )
        self.layout_stage = ThreadedPipelineStage(
            "layout",
            self.layout_model,
            opts.layout_batch_size,
            opts.batch_timeout_seconds,
            opts.queue_max_size,
        )
        self.table_stage = ThreadedPipelineStage(
            "table",
            self.table_model,
            opts.table_batch_size,
            opts.batch_timeout_seconds,
            opts.queue_max_size,
        )
        self.assemble_stage = ThreadedPipelineStage(
            "assemble",
            self.assemble_model,
            1,
            opts.batch_timeout_seconds,
            opts.queue_max_size,
        )

        # Create output queue for final results
        self.output_queue = ThreadedQueue(max_size=opts.queue_max_size)

        # Connect stages in pipeline order
        self.preprocess_stage.add_output_queue(self.ocr_stage.input_queue)
        self.ocr_stage.add_output_queue(self.layout_stage.input_queue)
        self.layout_stage.add_output_queue(self.table_stage.input_queue)
        self.table_stage.add_output_queue(self.assemble_stage.input_queue)
        self.assemble_stage.add_output_queue(self.output_queue)

        self.stages = [
            self.preprocess_stage,
            self.ocr_stage,
            self.layout_stage,
            self.table_stage,
            self.assemble_stage,
        ]
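
    # Stage wiring established above: one worker thread per stage, with every
    # hop going through a bounded ThreadedQueue (this is what applies
    # backpressure between stages).
    #
    #   preprocess -> ocr -> layout -> table -> assemble -> output_queue
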
    def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
        """Build the document by processing pages through the threaded pipeline"""
        if not isinstance(conv_res.input._backend, PdfDocumentBackend):
            raise RuntimeError(
                f"The selected backend {type(conv_res.input._backend).__name__} for {conv_res.input.file} is not a PDF backend."
            )

        with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
            # Initialize pages
            start_page, end_page = conv_res.input.limits.page_range
            pages_to_process = []

            for i in range(conv_res.input.page_count):
                if (start_page - 1) <= i <= (end_page - 1):
                    page = Page(page_no=i)
                    conv_res.pages.append(page)

                    # Initialize page backend
                    page._backend = conv_res.input._backend.load_page(i)
                    if page._backend and page._backend.is_valid():
                        page.size = page._backend.get_size()
                        pages_to_process.append(page)

            if not pages_to_process:
                conv_res.status = ConversionStatus.FAILURE
                return conv_res

            # Register document for tracking with weak reference
            doc_id = id(conv_res)
            with self._document_lock:
                self._document_tracker[doc_id] = conv_res

            # Start pipeline stages
            for stage in self.stages:
                stage.start()

            try:
                # Feed pages into pipeline
                self._feed_pipeline(pages_to_process, conv_res)

                # Collect results from pipeline with partial processing support
                result = self._collect_results_with_recovery(
                    conv_res, len(pages_to_process)
                )

                # Update conv_res with processed pages and handle partial results
                self._update_document_with_results(conv_res, result)

            finally:
                # Stop pipeline stages
                for stage in self.stages:
                    stage.stop()

                # Clean up stage resources
                for stage in self.stages:
                    stage.cleanup()

                # Clean up output queue
                self.output_queue.cleanup()

                # Clean up document tracking
                with self._document_lock:
                    self._document_tracker.pop(doc_id, None)

        return conv_res

    def _feed_pipeline(self, pages: List[Page], conv_res: ConversionResult):
        """Feed pages into the pipeline"""
        for page in pages:
            item = ThreadedItem(
                payload=page,
                conv_res_id=id(conv_res),
                conv_res=conv_res,
                page_no=page.page_no,
            )

            # Feed into first stage with timeout
            if not self.preprocess_stage.input_queue.put(
                item, timeout=self.pipeline_options.stage_timeout_seconds
            ):
                _log.warning(f"Failed to feed page {page.page_no} due to backpressure")

    def _collect_results_with_recovery(
        self, conv_res: ConversionResult, expected_count: int
    ) -> ProcessingResult:
        """Collect processed pages from the pipeline with partial result support"""
        result = ProcessingResult(total_expected=expected_count)
        doc_id = id(conv_res)

        # Collect from output queue
        while len(result.pages) + len(result.failed_pages) < expected_count:
            batch = self.output_queue.get_batch(
                batch_size=expected_count
                - len(result.pages)
                - len(result.failed_pages),
                timeout=self.pipeline_options.collection_timeout_seconds,
            )

            if not batch:
                # Timeout reached, log missing pages
                missing_count = (
                    expected_count - len(result.pages) - len(result.failed_pages)
                )
                if missing_count > 0:
                    _log.warning(f"Pipeline timeout: missing {missing_count} pages")
                break

            for item in batch:
                if item.conv_res_id == doc_id:
                    if item.is_failed or item.error is not None:
                        result.failed_pages.append(
                            (item.page_no, item.error or Exception("Unknown error"))
                        )
                        _log.warning(
                            f"Page {item.page_no} failed processing: {item.error}"
                        )
                    else:
                        result.pages.append(item.payload)

        return result

    def _update_document_with_results(
        self, conv_res: ConversionResult, result: ProcessingResult
    ):
        """Update document with processing results and handle partial success"""
        # Update conv_res with successfully processed pages
        page_map = {p.page_no: p for p in result.pages}
        valid_pages = []

        for page in conv_res.pages:
            if page.page_no in page_map:
                valid_pages.append(page_map[page.page_no])
            elif not any(
                failed_page_no == page.page_no
                for failed_page_no, _ in result.failed_pages
            ):
                # Page wasn't processed but also didn't explicitly fail - keep original
                valid_pages.append(page)

        conv_res.pages = valid_pages

        # Handle partial results
        if result.is_partial_success:
            _log.warning(
                f"Partial processing success: {result.success_count} pages succeeded, "
                f"{result.failure_count} pages failed"
            )
            conv_res.status = ConversionStatus.PARTIAL_SUCCESS
        elif result.is_complete_failure:
            _log.error("Complete processing failure: all pages failed")
            conv_res.status = ConversionStatus.FAILURE
        elif result.success_count > 0:
            # All expected pages processed successfully
            conv_res.status = ConversionStatus.SUCCESS

        # Clean up page resources if not keeping images
        if not self.keep_images:
            for page in conv_res.pages:
                # _image_cache is always present on Page objects, no need for hasattr
                page._image_cache = {}

        # Clean up page backends if not keeping them
        if not self.keep_backend:
            for page in conv_res.pages:
                if page._backend is not None:
                    page._backend.unload()

    def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
        """Assemble the final document from processed pages"""
        all_elements = []
        all_headers = []
        all_body = []

        with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
            for p in conv_res.pages:
                if p.assembled is not None:
                    for el in p.assembled.body:
                        all_body.append(el)
                    for el in p.assembled.headers:
                        all_headers.append(el)
                    for el in p.assembled.elements:
                        all_elements.append(el)

            conv_res.assembled = AssembledUnit(
                elements=all_elements, headers=all_headers, body=all_body
            )

            conv_res.document = self.reading_order_model(conv_res)

            # Generate page images
            if self.pipeline_options.generate_page_images:
                for page in conv_res.pages:
                    if page.image is not None:
                        page_no = page.page_no + 1
                        conv_res.document.pages[page_no].image = ImageRef.from_pil(
                            page.image, dpi=int(72 * self.pipeline_options.images_scale)
                        )

            # Generate element images
            self._generate_element_images(conv_res)

            # Aggregate confidence scores
            self._aggregate_confidence(conv_res)

        return conv_res

    def _generate_element_images(self, conv_res: ConversionResult):
        """Generate images for picture and table elements"""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            if (
                self.pipeline_options.generate_picture_images
                or self.pipeline_options.generate_table_images
            ):
                scale = self.pipeline_options.images_scale
                for element, _level in conv_res.document.iterate_items():
                    if not isinstance(element, DocItem) or len(element.prov) == 0:
                        continue
                    if (
                        isinstance(element, PictureItem)
                        and self.pipeline_options.generate_picture_images
                    ) or (
                        isinstance(element, TableItem)
                        and self.pipeline_options.generate_table_images
                    ):
                        page_ix = element.prov[0].page_no - 1
                        page = next(
                            (p for p in conv_res.pages if p.page_no == page_ix), None
                        )
                        if (
                            page is not None
                            and page.size is not None
                            and page.image is not None
                        ):
                            crop_bbox = (
                                element.prov[0]
                                .bbox.scaled(scale=scale)
                                .to_top_left_origin(
                                    page_height=page.size.height * scale
                                )
                            )
                            cropped_im = page.image.crop(crop_bbox.as_tuple())
                            element.image = ImageRef.from_pil(
                                cropped_im, dpi=int(72 * scale)
                            )

    def _aggregate_confidence(self, conv_res: ConversionResult):
        """Aggregate confidence scores across pages"""
        if len(conv_res.pages) > 0:
            import warnings

            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    category=RuntimeWarning,
                    message="Mean of empty slice|All-NaN slice encountered",
                )
                conv_res.confidence.layout_score = float(
                    np.nanmean(
                        [c.layout_score for c in conv_res.confidence.pages.values()]
                    )
                )
                conv_res.confidence.parse_score = float(
                    np.nanquantile(
                        [c.parse_score for c in conv_res.confidence.pages.values()],
                        q=0.1,
                    )
                )
                conv_res.confidence.table_score = float(
                    np.nanmean(
                        [c.table_score for c in conv_res.confidence.pages.values()]
                    )
                )
                conv_res.confidence.ocr_score = float(
                    np.nanmean(
                        [c.ocr_score for c in conv_res.confidence.pages.values()]
                    )
                )

    def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
        """Run enrichment models on the document"""

        def _prepare_elements(conv_res: ConversionResult, model: Any) -> Iterable[Any]:
            for doc_element, _level in conv_res.document.iterate_items():
                prepared_element = model.prepare_element(
                    conv_res=conv_res, element=doc_element
                )
                if prepared_element is not None:
                    yield prepared_element

        with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
            for model in self.enrichment_pipe:
                for element_batch in chunkify(
                    _prepare_elements(conv_res, model),
                    model.elements_batch_size,
                ):
                    for element in model(
                        doc=conv_res.document, element_batch=element_batch
                    ):  # Must exhaust!
                        pass

        return conv_res

    def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
        """Determine the final conversion status"""
        if conv_res.status == ConversionStatus.PARTIAL_SUCCESS:
            return ConversionStatus.PARTIAL_SUCCESS
        elif conv_res.pages and conv_res.document:
            return ConversionStatus.SUCCESS
        else:
            return ConversionStatus.FAILURE

    @classmethod
    def get_default_options(cls) -> ThreadedPdfPipelineOptions:
        return ThreadedPdfPipelineOptions()

    @classmethod
    def is_backend_supported(cls, backend):
        return isinstance(backend, PdfDocumentBackend)
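
A quick end-to-end sketch (illustrative only: the module path for ThreadedStandardPdfPipeline is assumed from the class name, and the option values and file name are placeholders). The threaded pipeline is selected per input format through PdfFormatOption, the same way any other pipeline class is swapped in:

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.threaded_standard_pdf_pipeline import (
        ThreadedStandardPdfPipeline,
    )

    opts = ThreadedPdfPipelineOptions(layout_batch_size=8, queue_max_size=50)
    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=ThreadedStandardPdfPipeline,
                pipeline_options=opts,
            )
        }
    )
    result = converter.convert("report.pdf")

Larger batch sizes tend to improve accelerator utilization at the cost of per-page latency, while queue_max_size bounds how far upstream stages can run ahead of slower ones.
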
@ -41,7 +41,7 @@ authors = [
    { name = "Panos Vagenas", email = "pva@zurich.ibm.com" },
    { name = "Peter Staar", email = "taa@zurich.ibm.com" },
]
requires-python = '>=3.11,<4.0'
requires-python = '>=3.9,<4.0'
dependencies = [
    'pydantic (>=2.0.0,<3.0.0)',
    'docling-core[chunking] (>=2.42.0,<3.0.0)',