Refactoring into async pipeline primitives and graph

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer 2025-07-16 10:12:51 +02:00
parent ef25d03bc8
commit 0be9349884
4 changed files with 654 additions and 671 deletions


@@ -1,11 +1,7 @@
import asyncio
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, AsyncIterable, Dict, List, Optional, Tuple
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import ConversionStatus, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions
@@ -17,8 +13,6 @@ from docling.models.document_picture_classifier import (
DocumentPictureClassifierOptions,
)
from docling.models.factories import get_ocr_factory, get_picture_description_factory
# Import the same models used by StandardPdfPipeline
from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
@@ -28,82 +22,36 @@ from docling.models.page_preprocessing_model import (
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.async_base_pipeline import AsyncPipeline
from docling.pipeline.resource_manager import (
AsyncPageTracker,
ConversionResultAccumulator,
)
from docling.pipeline.graph import GraphRunner
from docling.pipeline.resource_manager import AsyncPageTracker
from docling.pipeline.stages import (
AggregationStage,
BatchProcessorStage,
ExtractionStage,
PageProcessorStage,
SinkStage,
SourceStage,
)
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
@dataclass
class PageBatch:
"""Represents a batch of pages to process through models"""
pages: List[Page] = field(default_factory=list)
conv_results: List[ConversionResult] = field(default_factory=list)
start_time: float = field(default_factory=time.time)
@dataclass
class QueueTerminator:
"""Sentinel value for proper queue termination tracking"""
stage: str
error: Optional[Exception] = None
class UpstreamAwareQueue(asyncio.Queue):
"""Queue that tracks upstream completion to optimize batch processing"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.upstream_complete = asyncio.Event()
self.items_in_flight = 0
async def put(self, item):
"""Track items and upstream completion signals"""
if isinstance(item, QueueTerminator):
self.upstream_complete.set()
else:
self.items_in_flight += 1
await super().put(item)
async def get(self):
"""Track when items are consumed"""
item = await super().get()
if not isinstance(item, QueueTerminator):
self.items_in_flight -= 1
return item
def should_stop_waiting_for_batch(self) -> bool:
"""Determine if we should stop waiting and process current batch"""
return (
self.upstream_complete.is_set()
and self.items_in_flight == 0
and self.empty()
)
class AsyncStandardPdfPipeline(AsyncPipeline):
"""Async pipeline implementation with cross-document batching using structured concurrency"""
"""
An async, graph-based pipeline for processing PDFs with cross-document batching.
"""
def __init__(self, pipeline_options: AsyncPdfPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: AsyncPdfPipelineOptions = pipeline_options
# Resource management
self.page_tracker = AsyncPageTracker(
keep_images=self._should_keep_images(),
keep_backend=self._should_keep_backend(),
)
# Initialize models (same as StandardPdfPipeline)
self._initialize_models()
def _should_keep_images(self) -> bool:
"""Determine if images should be kept (same logic as StandardPdfPipeline)"""
return (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
@@ -111,7 +59,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
)
def _should_keep_backend(self) -> bool:
"""Determine if backend should be kept"""
return (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
@@ -120,36 +67,26 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
)
def _initialize_models(self):
"""Initialize all models (matching StandardPdfPipeline)"""
artifacts_path = self._get_artifacts_path()
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
# Build pipeline stages
self.preprocessing_model = PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=self.pipeline_options.images_scale,
)
)
self.ocr_model = self._get_ocr_model(artifacts_path)
self.layout_model = LayoutModel(
artifacts_path=artifacts_path,
accelerator_options=self.pipeline_options.accelerator_options,
options=self.pipeline_options.layout_options,
)
self.table_model = TableStructureModel(
enabled=self.pipeline_options.do_table_structure,
artifacts_path=artifacts_path,
options=self.pipeline_options.table_structure_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
# Enrichment models
self.code_formula_model = CodeFormulaModel(
enabled=self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_formula_enrichment,
@@ -160,20 +97,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
),
accelerator_options=self.pipeline_options.accelerator_options,
)
self.picture_classifier = DocumentPictureClassifier(
enabled=self.pipeline_options.do_picture_classification,
artifacts_path=artifacts_path,
options=DocumentPictureClassifierOptions(),
accelerator_options=self.pipeline_options.accelerator_options,
)
self.picture_description_model = self._get_picture_description_model(
artifacts_path
)
def _get_artifacts_path(self) -> Optional[str]:
"""Get artifacts path (same as StandardPdfPipeline)"""
from pathlib import Path
artifacts_path = None
@@ -190,7 +124,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
return artifacts_path
def _get_ocr_model(self, artifacts_path: Optional[str] = None) -> BaseOcrModel:
"""Get OCR model (same as StandardPdfPipeline)"""
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
@@ -202,7 +135,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
)
def _get_picture_description_model(self, artifacts_path: Optional[str] = None):
"""Get picture description model (same as StandardPdfPipeline)"""
factory = get_picture_description_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
@@ -217,606 +149,115 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def execute_stream(
self, input_docs: AsyncIterable[InputDocument]
) -> AsyncIterable[ConversionResult]:
"""Main async processing driven by a pipeline graph."""
stages = [
SourceStage("source"),
ExtractionStage(
"extractor",
self.page_tracker,
self.pipeline_options.max_concurrent_extractions,
),
PageProcessorStage("preprocessor", self.preprocessing_model),
BatchProcessorStage(
"ocr",
self.ocr_model,
self.pipeline_options.ocr_batch_size,
self.pipeline_options.batch_timeout_seconds,
),
BatchProcessorStage(
"layout",
self.layout_model,
self.pipeline_options.layout_batch_size,
self.pipeline_options.batch_timeout_seconds,
),
BatchProcessorStage(
"table",
self.table_model,
self.pipeline_options.table_batch_size,
self.pipeline_options.batch_timeout_seconds,
),
PageProcessorStage("assembler", self.assemble_model),
AggregationStage("aggregator", self.page_tracker, self._finalize_document),
SinkStage("sink"),
]
edges = [
# Main processing path
{
"from_stage": "source",
"from_output": "out",
"to_stage": "extractor",
"to_input": "in",
},
{
"from_stage": "extractor",
"from_output": "out",
"to_stage": "preprocessor",
"to_input": "in",
},
{
"from_stage": "preprocessor",
"from_output": "out",
"to_stage": "ocr",
"to_input": "in",
},
{
"from_stage": "ocr",
"from_output": "out",
"to_stage": "layout",
"to_input": "in",
},
{
"from_stage": "layout",
"from_output": "out",
"to_stage": "table",
"to_input": "in",
},
{
"from_stage": "table",
"from_output": "out",
"to_stage": "assembler",
"to_input": "in",
},
{
"from_stage": "assembler",
"from_output": "out",
"to_stage": "aggregator",
"to_input": "in",
},
# Failure path
{
"from_stage": "extractor",
"from_output": "fail",
"to_stage": "aggregator",
"to_input": "fail",
},
# Final output
{
"from_stage": "aggregator",
"from_output": "out",
"to_stage": "sink",
"to_input": "in",
},
]
runner = GraphRunner(stages, edges)
source_config = {"stage": "source", "channel": "out"}
sink_config = {"stage": "sink", "channel": "in"}
try:
async for result in runner.run(
input_docs,
source_config,
sink_config,
self.pipeline_options.extraction_queue_size,
):
yield result
except* Exception as eg:
_log.error(f"Pipeline failed with exceptions: {eg.exceptions}")
raise (eg.exceptions[0] if eg.exceptions else RuntimeError("Unknown error"))
finally:
await self.page_tracker.cleanup_all()
"""Main async processing with structured concurrency and proper exception handling"""
# Create queues for pipeline stages
page_queue = UpstreamAwareQueue(
maxsize=self.pipeline_options.extraction_queue_size
)
completed_queue = UpstreamAwareQueue()
completed_docs = UpstreamAwareQueue()
# Track active documents for proper termination
doc_tracker = {"active_docs": 0, "extraction_done": False}
doc_lock = asyncio.Lock()
# Create exception event for coordinated shutdown
exception_event = asyncio.Event()
async def track_document_start():
async with doc_lock:
doc_tracker["active_docs"] += 1
async def track_document_complete():
async with doc_lock:
doc_tracker["active_docs"] -= 1
if doc_tracker["extraction_done"] and doc_tracker["active_docs"] == 0:
# All documents completed
await completed_docs.put(None)
try:
async with asyncio.TaskGroup() as tg:
# Start all tasks
tg.create_task(
self._extract_documents_wrapper(
input_docs,
page_queue,
track_document_start,
exception_event,
doc_tracker,
doc_lock,
)
)
tg.create_task(
self._process_pages_wrapper(
page_queue, completed_queue, exception_event
)
)
tg.create_task(
self._aggregate_results_wrapper(
completed_queue,
completed_docs,
track_document_complete,
exception_event,
)
)
# Yield results as they complete
async for result in self._yield_results(
completed_docs, exception_event
):
yield result
except* Exception as eg:
# Handle exception group from TaskGroup
_log.error(f"Pipeline failed with exceptions: {eg.exceptions}")
# Re-raise the first exception
raise (eg.exceptions[0] if eg.exceptions else RuntimeError("Unknown error"))
finally:
# Ensure cleanup
await self.page_tracker.cleanup_all()
async def _extract_documents_wrapper(
self,
input_docs: AsyncIterable[InputDocument],
page_queue: UpstreamAwareQueue,
track_document_start,
exception_event: asyncio.Event,
doc_tracker: Dict[str, Any],
doc_lock: asyncio.Lock,
):
"""Wrapper for document extraction with exception handling"""
try:
await self._extract_documents_safe(
input_docs,
page_queue,
track_document_start,
exception_event,
)
except Exception:
exception_event.set()
raise
finally:
async with doc_lock:
doc_tracker["extraction_done"] = True
# Send termination signal
await page_queue.put(QueueTerminator("extraction"))
async def _process_pages_wrapper(
self,
page_queue: UpstreamAwareQueue,
completed_queue: UpstreamAwareQueue,
exception_event: asyncio.Event,
):
"""Wrapper for page processing with exception handling"""
try:
await self._process_pages_safe(page_queue, completed_queue, exception_event)
except Exception:
exception_event.set()
raise
finally:
# Send termination signal
await completed_queue.put(QueueTerminator("processing"))
async def _aggregate_results_wrapper(
self,
completed_queue: UpstreamAwareQueue,
completed_docs: UpstreamAwareQueue,
track_document_complete,
exception_event: asyncio.Event,
):
"""Wrapper for result aggregation with exception handling"""
try:
await self._aggregate_results_safe(
completed_queue,
completed_docs,
track_document_complete,
exception_event,
)
except Exception:
exception_event.set()
raise
async def _yield_results(
self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
):
"""Yield results as they complete"""
while True:
if exception_event.is_set():
break
try:
result = await asyncio.wait_for(completed_docs.get(), timeout=1.0)
if result is None:
break
yield result
except asyncio.TimeoutError:
continue
except Exception:
exception_event.set()
raise
async def _extract_documents_safe(
self,
input_docs: AsyncIterable[InputDocument],
page_queue: UpstreamAwareQueue,
track_document_start,
exception_event: asyncio.Event,
) -> None:
"""Extract pages from documents with exception handling"""
async for in_doc in input_docs:
if exception_event.is_set():
break
await track_document_start()
conv_res = ConversionResult(input=in_doc)
# Validate backend
if not isinstance(conv_res.input._backend, PdfDocumentBackend):
conv_res.status = ConversionStatus.FAILURE
await page_queue.put((None, conv_res)) # Signal failed document
continue
try:
# Initialize document
total_pages = conv_res.input.page_count
await self.page_tracker.register_document(conv_res, total_pages)
# Extract pages with limited concurrency
semaphore = asyncio.Semaphore(
self.pipeline_options.max_concurrent_extractions
)
async def extract_page(page_no: int):
if exception_event.is_set():
return
async with semaphore:
# Create page
page = Page(page_no=page_no)
conv_res.pages.append(page)
# Initialize page backend
page._backend = await asyncio.to_thread(
conv_res.input._backend.load_page, page_no
)
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
await self.page_tracker.track_page_loaded(page, conv_res)
# Send to processing queue
await page_queue.put((page, conv_res))
# Extract all pages concurrently
async with asyncio.TaskGroup() as tg:
for i in range(total_pages):
if exception_event.is_set():
break
start_page, end_page = conv_res.input.limits.page_range
if (start_page - 1) <= i <= (end_page - 1):
tg.create_task(extract_page(i))
except Exception as e:
_log.error(f"Failed to extract document {in_doc.file.name}: {e}")
conv_res.status = ConversionStatus.FAILURE
# Signal document failure
await page_queue.put((None, conv_res))
raise
async def _process_pages_safe(
self,
page_queue: UpstreamAwareQueue,
completed_queue: UpstreamAwareQueue,
exception_event: asyncio.Event,
) -> None:
"""Process pages through model pipeline with proper termination"""
# Process batches through each model stage
preprocessing_queue = UpstreamAwareQueue()
ocr_queue = UpstreamAwareQueue()
layout_queue = UpstreamAwareQueue()
table_queue = UpstreamAwareQueue()
assemble_queue = UpstreamAwareQueue()
# Start processing stages using TaskGroup
async with asyncio.TaskGroup() as tg:
# Preprocessing stage
tg.create_task(
self._batch_process_stage_safe(
page_queue,
preprocessing_queue,
self._preprocess_batch,
1,
0, # No batching for preprocessing
"preprocessing",
exception_event,
)
)
# OCR stage
tg.create_task(
self._batch_process_stage_safe(
preprocessing_queue,
ocr_queue,
self._ocr_batch,
self.pipeline_options.ocr_batch_size,
self.pipeline_options.batch_timeout_seconds,
"ocr",
exception_event,
)
)
# Layout stage
tg.create_task(
self._batch_process_stage_safe(
ocr_queue,
layout_queue,
self._layout_batch,
self.pipeline_options.layout_batch_size,
self.pipeline_options.batch_timeout_seconds,
"layout",
exception_event,
)
)
# Table stage
tg.create_task(
self._batch_process_stage_safe(
layout_queue,
table_queue,
self._table_batch,
self.pipeline_options.table_batch_size,
self.pipeline_options.batch_timeout_seconds,
"table",
exception_event,
)
)
# Assembly stage
tg.create_task(
self._batch_process_stage_safe(
table_queue,
assemble_queue,
self._assemble_batch,
1,
0, # No batching for assembly
"assembly",
exception_event,
)
)
# Finalization stage
tg.create_task(
self._finalize_pages_safe(
assemble_queue, completed_queue, exception_event
)
)
async def _batch_process_stage_safe(
self,
input_queue: UpstreamAwareQueue,
output_queue: UpstreamAwareQueue,
process_func,
batch_size: int,
timeout: float,
stage_name: str,
exception_event: asyncio.Event,
) -> None:
"""Generic batch processing stage with proper termination handling"""
batch = PageBatch()
try:
while not exception_event.is_set():
# Collect batch
try:
# Get first item or wait for timeout
if not batch.pages:
item = await input_queue.get()
# Check for termination
if isinstance(item, QueueTerminator):
# Propagate termination signal
await output_queue.put(item)
break
# Handle failed document signal
if item[0] is None:
# Pass through failure signal
await output_queue.put(item)
continue
batch.pages.append(item[0])
batch.conv_results.append(item[1])
# Try to fill batch up to batch_size or until upstream exhausted
while len(batch.pages) < batch_size:
# Check if upstream is exhausted and we should process immediately
if input_queue.should_stop_waiting_for_batch():
break
# Calculate remaining time for batch timeout
remaining_time = timeout - (time.time() - batch.start_time)
# If upstream is complete, use minimal timeout to drain quickly
if input_queue.upstream_complete.is_set():
remaining_time = min(remaining_time, 0.05)
if remaining_time <= 0:
break
try:
item = await asyncio.wait_for(
input_queue.get(), timeout=remaining_time
)
# Check for termination
if isinstance(item, QueueTerminator):
# Put it back and process current batch
await input_queue.put(item)
break
# Handle failed document signal
if item[0] is None:
# Put it back and process current batch
await input_queue.put(item)
break
batch.pages.append(item[0])
batch.conv_results.append(item[1])
except asyncio.TimeoutError:
break
# Process batch
if batch.pages:
processed = await process_func(batch)
# Send results to output queue
for page, conv_res in processed:
await output_queue.put((page, conv_res))
# Clear batch
batch = PageBatch()
except Exception as e:
_log.error(f"Error in {stage_name} batch processing: {e}")
# Send failed items downstream
for page, conv_res in zip(batch.pages, batch.conv_results):
await output_queue.put((page, conv_res))
batch = PageBatch()
raise
except Exception as e:
# Set exception event and propagate termination
exception_event.set()
await output_queue.put(QueueTerminator(stage_name, error=e))
raise
async def _preprocess_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Preprocess pages (no actual batching needed)"""
results = []
for page, conv_res in zip(batch.pages, batch.conv_results):
processed_page = await asyncio.to_thread(
lambda: next(iter(self.preprocessing_model(conv_res, [page])))
)
results.append((processed_page, conv_res))
return results
async def _ocr_batch(self, batch: PageBatch) -> List[Tuple[Page, ConversionResult]]:
"""Process OCR in batch"""
# Group by conversion result for proper context
grouped = defaultdict(list)
for page, conv_res in zip(batch.pages, batch.conv_results):
grouped[id(conv_res)].append(page)
results = []
for conv_res_id, pages in grouped.items():
# Find the conv_res
conv_res = next(
cr
for p, cr in zip(batch.pages, batch.conv_results)
if id(cr) == conv_res_id
)
# Process batch through OCR model
processed_pages = await asyncio.to_thread(
lambda: list(self.ocr_model(conv_res, pages))
)
for page in processed_pages:
results.append((page, conv_res))
return results
async def _layout_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Process layout in batch"""
# Similar batching as OCR
grouped = defaultdict(list)
for page, conv_res in zip(batch.pages, batch.conv_results):
grouped[id(conv_res)].append(page)
results = []
for conv_res_id, pages in grouped.items():
conv_res = next(
cr
for p, cr in zip(batch.pages, batch.conv_results)
if id(cr) == conv_res_id
)
processed_pages = await asyncio.to_thread(
lambda: list(self.layout_model(conv_res, pages))
)
for page in processed_pages:
results.append((page, conv_res))
return results
async def _table_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Process tables in batch"""
grouped = defaultdict(list)
for page, conv_res in zip(batch.pages, batch.conv_results):
grouped[id(conv_res)].append(page)
results = []
for conv_res_id, pages in grouped.items():
conv_res = next(
cr
for p, cr in zip(batch.pages, batch.conv_results)
if id(cr) == conv_res_id
)
processed_pages = await asyncio.to_thread(
lambda: list(self.table_model(conv_res, pages))
)
for page in processed_pages:
results.append((page, conv_res))
return results
async def _assemble_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Assemble pages (no actual batching needed)"""
results = []
for page, conv_res in zip(batch.pages, batch.conv_results):
assembled_page = await asyncio.to_thread(
lambda: next(iter(self.assemble_model(conv_res, [page])))
)
results.append((assembled_page, conv_res))
return results
async def _finalize_pages_safe(
self,
input_queue: UpstreamAwareQueue,
output_queue: UpstreamAwareQueue,
exception_event: asyncio.Event,
) -> None:
"""Finalize pages and track completion with proper termination"""
try:
while not exception_event.is_set():
item = await input_queue.get()
# Check for termination
if isinstance(item, QueueTerminator):
# Propagate termination signal
await output_queue.put(item)
break
# Handle failed document signal
if item[0] is None:
# Pass through failure signal
await output_queue.put(item)
continue
page, conv_res = item
# Track page completion for resource cleanup
await self.page_tracker.track_page_completion(page, conv_res)
# Send to output
await output_queue.put((page, conv_res))
except Exception as e:
exception_event.set()
await output_queue.put(QueueTerminator("finalization", error=e))
raise
async def _aggregate_results_safe(
self,
completed_queue: UpstreamAwareQueue,
completed_docs: UpstreamAwareQueue,
track_document_complete,
exception_event: asyncio.Event,
) -> None:
"""Aggregate completed pages into documents with proper termination"""
doc_pages = defaultdict(list)
failed_docs = set()
try:
while not exception_event.is_set():
item = await completed_queue.get()
# Check for termination
if isinstance(item, QueueTerminator):
# Finalize any remaining documents
for conv_res_id, pages in doc_pages.items():
if conv_res_id not in failed_docs:
# Find conv_res from first page
conv_res = pages[0][1]
await self._finalize_document(conv_res)
await completed_docs.put(conv_res)
await track_document_complete()
break
# Handle failed document signal
if item[0] is None:
conv_res = item[1]
doc_id = id(conv_res)
failed_docs.add(doc_id)
# Send failed document immediately
await completed_docs.put(conv_res)
await track_document_complete()
continue
page, conv_res = item
doc_id = id(conv_res)
if doc_id not in failed_docs:
doc_pages[doc_id].append((page, conv_res))
# Check if document is complete
if len(doc_pages[doc_id]) == len(conv_res.pages):
await self._finalize_document(conv_res)
await completed_docs.put(conv_res)
await track_document_complete()
del doc_pages[doc_id]
except Exception:
exception_event.set()
# Try to send any completed documents before failing
for conv_res_id, pages in doc_pages.items():
if conv_res_id not in failed_docs and pages:
conv_res = pages[0][1]
conv_res.status = ConversionStatus.PARTIAL_SUCCESS
await completed_docs.put(conv_res)
await track_document_complete()
raise
async def _finalize_document(self, conv_res: ConversionResult) -> None:
"""Finalize a complete document (same as StandardPdfPipeline._assemble_document)"""

docling/pipeline/graph.py (new file, 144 additions)

@@ -0,0 +1,144 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, AsyncIterable, Dict, List, Literal, Optional
# Sentinel to signal stream completion
STOP_SENTINEL = object()
@dataclass(slots=True)
class StreamItem:
"""
A wrapper for data flowing through the pipeline, maintaining a link
to the original conversion result context.
"""
payload: Any
conv_res_id: int
conv_res: Any # Opaque reference to ConversionResult
class PipelineStage(ABC):
"""A single, encapsulated step in a processing pipeline graph."""
def __init__(self, name: str):
self.name = name
self.input_queues: Dict[str, asyncio.Queue] = {}
self.output_queues: Dict[str, List[asyncio.Queue]] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
@abstractmethod
async def run(self) -> None:
"""
The core execution logic for the stage. This method is responsible for
consuming from input queues, processing data, and putting results into
output queues.
"""
async def _send_to_outputs(self, channel: str, items: List[StreamItem] | List[Any]):
"""Helper to send processed items to all connected output queues."""
if channel in self.output_queues:
for queue in self.output_queues[channel]:
for item in items:
await queue.put(item)
async def _signal_downstream_completion(self):
"""Signal that this stage is done processing to all output channels."""
for channel_queues in self.output_queues.values():
for queue in channel_queues:
await queue.put(STOP_SENTINEL)
@property
def loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop
class GraphRunner:
"""Connects stages and runs the pipeline graph."""
def __init__(self, stages: List[PipelineStage], edges: List[Dict[str, str]]):
self._stages = {s.name: s for s in stages}
self._edges = edges
def _wire_graph(self, queue_max_size: int):
"""Create queues for edges and connect them to stage inputs and outputs."""
for edge in self._edges:
from_stage, from_output = edge["from_stage"], edge["from_output"]
to_stage, to_input = edge["to_stage"], edge["to_input"]
queue = asyncio.Queue(maxsize=queue_max_size)
# Connect to source stage's output
self._stages[from_stage].output_queues.setdefault(from_output, []).append(
queue
)
# Connect to destination stage's input
self._stages[to_stage].input_queues[to_input] = queue
async def _run_source(
self,
source_stream: AsyncIterable[Any],
source_stage: str,
source_channel: str,
):
"""Feed the graph from an external async iterable."""
output_queues = self._stages[source_stage].output_queues.get(source_channel, [])
async for item in source_stream:
for queue in output_queues:
await queue.put(item)
# Signal completion to all downstream queues
for queue in output_queues:
await queue.put(STOP_SENTINEL)
async def _run_sink(self, sink_stage: str, sink_channel: str) -> AsyncIterable[Any]:
"""Yield results from the graph's final output queue."""
queue = self._stages[sink_stage].input_queues[sink_channel]
while True:
item = await queue.get()
if item is STOP_SENTINEL:
break
yield item
await queue.put(STOP_SENTINEL) # Allow other sinks to terminate
async def run(
self,
source_stream: AsyncIterable,
source_config: Dict[str, str],
sink_config: Dict[str, str],
queue_max_size: int = 32,
) -> AsyncIterable:
"""
Executes the entire pipeline graph.
Args:
source_stream: The initial async iterable to feed the graph.
source_config: Dictionary with "stage" and "channel" for the entry point.
sink_config: Dictionary with "stage" and "channel" for the exit point.
queue_max_size: The max size for the internal asyncio.Queues.
"""
self._wire_graph(queue_max_size)
async with asyncio.TaskGroup() as tg:
# Create a task for the source feeder
tg.create_task(
self._run_source(
source_stream, source_config["stage"], source_config["channel"]
)
)
# Create tasks for all pipeline stages
for stage in self._stages.values():
tg.create_task(stage.run())
# Yield results from the sink
async for result in self._run_sink(
sink_config["stage"], sink_config["channel"]
):
yield result
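
As a usage illustration (not part of the commit), a minimal two-stage graph can be wired and run as below, assuming graph.py is importable as docling.pipeline.graph, the path used by the pipeline import above. EndpointStage, UppercaseStage, and the string payloads are toy stand-ins chosen only to exercise the GraphRunner API.

import asyncio

from docling.pipeline.graph import STOP_SENTINEL, GraphRunner, PipelineStage


class EndpointStage(PipelineStage):
    """No-op endpoints; GraphRunner feeds and drains their queues directly."""

    async def run(self) -> None:
        pass


class UppercaseStage(PipelineStage):
    """Toy middle stage: upper-cases every string it receives."""

    async def run(self) -> None:
        q_in = self.input_queues["in"]
        while True:
            item = await q_in.get()
            if item is STOP_SENTINEL:
                await self._signal_downstream_completion()
                break
            await self._send_to_outputs("out", [item.upper()])


async def main() -> None:
    stages = [EndpointStage("source"), UppercaseStage("upper"), EndpointStage("sink")]
    edges = [
        {"from_stage": "source", "from_output": "out", "to_stage": "upper", "to_input": "in"},
        {"from_stage": "upper", "from_output": "out", "to_stage": "sink", "to_input": "in"},
    ]

    async def words():
        for word in ("graph", "runner"):
            yield word

    runner = GraphRunner(stages, edges)
    async for result in runner.run(
        words(),
        {"stage": "source", "channel": "out"},
        {"stage": "sink", "channel": "in"},
    ):
        print(result)  # GRAPH, RUNNER


asyncio.run(main())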


@@ -0,0 +1,125 @@
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Set
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.pipeline.async_base_pipeline import DocumentTracker
_log = logging.getLogger(__name__)
@dataclass
class AsyncPageTracker:
"""Manages page backend lifecycle across documents"""
_doc_trackers: Dict[str, DocumentTracker] = field(default_factory=dict)
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
keep_images: bool = False
keep_backend: bool = False
async def register_document(
self, conv_res: ConversionResult, total_pages: int
) -> str:
"""Register a new document for tracking"""
async with self._lock:
doc_id = str(id(conv_res))
self._doc_trackers[doc_id] = DocumentTracker(
doc_id=doc_id, total_pages=total_pages, conv_result=conv_res
)
return doc_id
async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None:
"""Track when a page backend is loaded"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id in self._doc_trackers and page._backend is not None:
self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend
async def track_page_completion(
self, page: Page, conv_res: ConversionResult
) -> bool:
"""Track page completion and cleanup when all pages done"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id not in self._doc_trackers:
_log.warning(f"Document {doc_id} not registered for tracking")
return False
tracker = self._doc_trackers[doc_id]
tracker.processed_pages += 1
# Clear this page's image cache if needed
if not self.keep_images:
page._image_cache = {}
# If all pages from this document are processed, cleanup
if tracker.processed_pages == tracker.total_pages:
await self._cleanup_document_resources(tracker)
del self._doc_trackers[doc_id]
return True # Document is complete
return False # Document is not yet complete
async def _cleanup_document_resources(self, tracker: DocumentTracker) -> None:
"""Cleanup all resources for a completed document"""
if not self.keep_backend:
# Unload all page backends for this document
for page_no, backend in tracker.page_backends.items():
if backend is not None:
try:
# Run unload in thread to avoid blocking
await asyncio.to_thread(backend.unload)
except Exception as e:
_log.warning(
f"Failed to unload backend for page {page_no}: {e}"
)
tracker.page_backends.clear()
_log.debug(f"Cleaned up resources for document {tracker.doc_id}")
async def cleanup_all(self) -> None:
"""Cleanup all tracked documents - for shutdown"""
async with self._lock:
for tracker in self._doc_trackers.values():
await self._cleanup_document_resources(tracker)
self._doc_trackers.clear()
@dataclass
class ConversionResultAccumulator:
"""Accumulates updates to ConversionResult without immediate mutation"""
_updates: Dict[str, Dict] = field(default_factory=dict) # doc_id -> updates
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
async def accumulate_page_result(
self, page_no: int, conv_res: ConversionResult, updates: Dict
) -> None:
"""Accumulate updates for later application"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id not in self._updates:
self._updates[doc_id] = {}
if page_no not in self._updates[doc_id]:
self._updates[doc_id][page_no] = {}
self._updates[doc_id][page_no].update(updates)
async def flush_to_conv_res(self, conv_res: ConversionResult) -> None:
"""Apply all accumulated updates atomically"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id in self._updates:
# Apply updates
for page_no, updates in self._updates[doc_id].items():
# Find the page and apply updates
for page in conv_res.pages:
if page.page_no == page_no:
for key, value in updates.items():
setattr(page, key, value)
break
del self._updates[doc_id]
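
An illustrative sketch (not part of the commit) of the intended AsyncPageTracker call order, assuming this module is the docling.pipeline.resource_manager imported by the pipeline above. The SimpleNamespace objects are stand-ins for real Page and ConversionResult instances, and the demo assumes DocumentTracker defaults processed_pages to 0 and page_backends to an empty mapping, as the code above implies.

import asyncio
from types import SimpleNamespace

from docling.pipeline.resource_manager import AsyncPageTracker


async def demo() -> None:
    tracker = AsyncPageTracker(keep_images=False, keep_backend=False)

    conv_res = SimpleNamespace()  # stand-in for ConversionResult
    page = SimpleNamespace(page_no=0, _backend=None, _image_cache={})  # stand-in for Page

    await tracker.register_document(conv_res, total_pages=1)
    await tracker.track_page_loaded(page, conv_res)  # no backend recorded: _backend is None
    done = await tracker.track_page_completion(page, conv_res)
    print("document complete:", done)  # True once all registered pages are counted

    await tracker.cleanup_all()  # safe to call again at shutdown


asyncio.run(demo())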

docling/pipeline/stages.py (new file, 273 additions)

@@ -0,0 +1,273 @@
from __future__ import annotations
import asyncio
import time
from collections import defaultdict
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
from docling.datamodel.base_models import ConversionStatus, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem
class SourceStage(PipelineStage):
"""A placeholder stage to represent the entry point of the graph."""
async def run(self) -> None:
# This stage is driven by the GraphRunner's _run_source method
# and does not have its own execution loop.
pass
class SinkStage(PipelineStage):
"""A placeholder stage to represent the exit point of the graph."""
async def run(self) -> None:
# This stage is read by the GraphRunner's _run_sink method
# and does not have its own execution loop.
pass
class ExtractionStage(PipelineStage):
"""Extracts pages from documents and tracks them."""
def __init__(
self,
name: str,
page_tracker: Any,
max_concurrent_extractions: int,
):
super().__init__(name)
self.page_tracker = page_tracker
self.semaphore = asyncio.Semaphore(max_concurrent_extractions)
self.input_channel = "in"
self.output_channel = "out"
self.failure_channel = "fail"
async def _extract_page(
self, page_no: int, conv_res: ConversionResult
) -> StreamItem | None:
"""Coroutine to extract a single page."""
try:
async with self.semaphore:
page = Page(page_no=page_no)
conv_res.pages.append(page)
page._backend = await self.loop.run_in_executor(
None, conv_res.input._backend.load_page, page_no
)
if page._backend and page._backend.is_valid():
page.size = page._backend.get_size()
await self.page_tracker.track_page_loaded(page, conv_res)
return StreamItem(
payload=page, conv_res_id=id(conv_res), conv_res=conv_res
)
except Exception:
# A page-level error is swallowed here and the page is skipped;
# for now, failures are only propagated at the document level
# (see _process_document).
pass
return None
async def _process_document(self, in_doc: InputDocument):
"""Processes a single document, extracting all its pages."""
conv_res = ConversionResult(input=in_doc)
try:
from docling.backend.pdf_backend import PdfDocumentBackend
if not isinstance(in_doc._backend, PdfDocumentBackend):
raise TypeError("Backend is not a valid PdfDocumentBackend")
total_pages = in_doc.page_count
await self.page_tracker.register_document(conv_res, total_pages)
start_page, end_page = conv_res.input.limits.page_range
page_indices_to_extract = [
i for i in range(total_pages) if (start_page - 1) <= i <= (end_page - 1)
]
tasks = [
self.loop.create_task(self._extract_page(i, conv_res))
for i in page_indices_to_extract
]
pages_extracted = await asyncio.gather(*tasks)
await self._send_to_outputs(
self.output_channel, [p for p in pages_extracted if p]
)
except Exception:
conv_res.status = "FAILURE"
await self._send_to_outputs(self.failure_channel, [conv_res])
async def run(self) -> None:
"""Main loop to consume documents and launch extraction tasks."""
q_in = self.input_queues[self.input_channel]
while True:
doc = await q_in.get()
if doc is STOP_SENTINEL:
await self._signal_downstream_completion()
break
await self._process_document(doc)
class PageProcessorStage(PipelineStage):
"""Applies a synchronous, 1-to-1 processing function to each page."""
def __init__(self, name: str, model: Any):
super().__init__(name)
self.model = model
self.input_channel = "in"
self.output_channel = "out"
async def run(self) -> None:
q_in = self.input_queues[self.input_channel]
while True:
item = await q_in.get()
if item is STOP_SENTINEL:
await self._signal_downstream_completion()
break
# The model call is sync, run in thread to avoid blocking event loop
processed_page = await self.loop.run_in_executor(
None, lambda: next(iter(self.model(item.conv_res, [item.payload])))
)
item.payload = processed_page
await self._send_to_outputs(self.output_channel, [item])
class BatchProcessorStage(PipelineStage):
"""Batches items and applies a synchronous model to the batch."""
def __init__(
self,
name: str,
model: Any,
batch_size: int,
batch_timeout: float,
):
super().__init__(name)
self.model = model
self.batch_size = batch_size
self.batch_timeout = batch_timeout
self.input_channel = "in"
self.output_channel = "out"
async def _collect_batch(self, q_in: asyncio.Queue) -> List[StreamItem] | None:
"""Collects a batch of items from the input queue with a timeout."""
try:
# Wait for the first item without a timeout
first_item = await q_in.get()
if first_item is STOP_SENTINEL:
return None # End of stream
except asyncio.CancelledError:
return None
batch = [first_item]
start_time = self.loop.time()
while len(batch) < self.batch_size:
timeout = self.batch_timeout - (self.loop.time() - start_time)
if timeout <= 0:
break
try:
item = await asyncio.wait_for(q_in.get(), timeout)
if item is STOP_SENTINEL:
# Put sentinel back for other potential consumers or the main loop
await q_in.put(STOP_SENTINEL)
break
batch.append(item)
except asyncio.TimeoutError:
break # Batching timeout reached
return batch
async def run(self) -> None:
q_in = self.input_queues[self.input_channel]
while True:
batch = await self._collect_batch(q_in)
if not batch: # This can be None or an empty list.
await self._signal_downstream_completion()
break
# Group pages by their original ConversionResult
grouped_by_doc = defaultdict(list)
for item in batch:
grouped_by_doc[item.conv_res_id].append(item)
processed_items = []
for conv_res_id, items in grouped_by_doc.items():
conv_res = items[0].conv_res
pages = [item.payload for item in items]
# The model call is sync, run in thread
processed_pages = await self.loop.run_in_executor(
None, lambda: list(self.model(conv_res, pages))
)
# Re-wrap the processed pages into StreamItems
for i, page in enumerate(processed_pages):
processed_items.append(
StreamItem(
payload=page,
conv_res_id=items[i].conv_res_id,
conv_res=items[i].conv_res,
)
)
await self._send_to_outputs(self.output_channel, processed_items)
class AggregationStage(PipelineStage):
"""Aggregates processed pages back into completed documents."""
def __init__(
self,
name: str,
page_tracker: Any,
finalizer_func: Callable[[ConversionResult], Coroutine],
):
super().__init__(name)
self.page_tracker = page_tracker
self.finalizer_func = finalizer_func
self.success_channel = "in"
self.failure_channel = "fail"
self.output_channel = "out"
async def run(self) -> None:
success_q = self.input_queues[self.success_channel]
failure_q = self.input_queues.get(self.failure_channel)
doc_completers: Dict[int, asyncio.Future] = {}
async def handle_successes():
while True:
item = await success_q.get()
if item is STOP_SENTINEL:
break
is_doc_complete = await self.page_tracker.track_page_completion(
item.payload, item.conv_res
)
if is_doc_complete:
await self.finalizer_func(item.conv_res)
await self._send_to_outputs(self.output_channel, [item.conv_res])
if item.conv_res_id in doc_completers:
doc_completers.pop(item.conv_res_id).set_result(True)
async def handle_failures():
if failure_q is None:
return # No failure channel, nothing to do
while True:
failed_res = await failure_q.get()
if failed_res is STOP_SENTINEL:
break
await self._send_to_outputs(self.output_channel, [failed_res])
if id(failed_res) in doc_completers:
doc_completers.pop(id(failed_res)).set_result(True)
# Create tasks only for channels that exist
tasks = [handle_successes()]
if failure_q is not None:
tasks.append(handle_failures())
await asyncio.gather(*tasks)
await self._signal_downstream_completion()