Refactoring into async pipeline primitives and graph

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer 2025-07-16 10:12:51 +02:00
parent ef25d03bc8
commit 0be9349884
4 changed files with 654 additions and 671 deletions

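For orientation before the diff: every stage in the new design reads from an asyncio queue and stops on a shared sentinel object. The standalone toy below (plain asyncio, not code from this commit) sketches that pattern; STOP plays the role of the STOP_SENTINEL defined in docling/pipeline/graph.py further down.

import asyncio

STOP = object()  # end-of-stream marker, analogous to STOP_SENTINEL in graph.py below


async def producer(queue: asyncio.Queue) -> None:
    for item in range(3):
        await queue.put(item)
    await queue.put(STOP)  # tell the consumer that no more items will arrive


async def consumer(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        if item is STOP:
            break
        print("processed", item)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=8)
    async with asyncio.TaskGroup() as tg:  # Python 3.11+, as used throughout the diff
        tg.create_task(producer(queue))
        tg.create_task(consumer(queue))


asyncio.run(main())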

@@ -1,11 +1,7 @@
import asyncio
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, AsyncIterable, Dict, List, Optional, Tuple
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import ConversionStatus, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions
@@ -17,8 +13,6 @@ from docling.models.document_picture_classifier (
DocumentPictureClassifierOptions,
)
from docling.models.factories import get_ocr_factory, get_picture_description_factory
# Import the same models used by StandardPdfPipeline
from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
@@ -28,82 +22,36 @@ from docling.models.page_preprocessing_model (
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.async_base_pipeline import AsyncPipeline
from docling.pipeline.resource_manager import (
AsyncPageTracker,
ConversionResultAccumulator,
from docling.pipeline.graph import GraphRunner
from docling.pipeline.resource_manager import AsyncPageTracker
from docling.pipeline.stages import (
AggregationStage,
BatchProcessorStage,
ExtractionStage,
PageProcessorStage,
SinkStage,
SourceStage,
)
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
@dataclass
class PageBatch:
"""Represents a batch of pages to process through models"""
pages: List[Page] = field(default_factory=list)
conv_results: List[ConversionResult] = field(default_factory=list)
start_time: float = field(default_factory=time.time)
@dataclass
class QueueTerminator:
"""Sentinel value for proper queue termination tracking"""
stage: str
error: Optional[Exception] = None
class UpstreamAwareQueue(asyncio.Queue):
"""Queue that tracks upstream completion to optimize batch processing"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.upstream_complete = asyncio.Event()
self.items_in_flight = 0
async def put(self, item):
"""Track items and upstream completion signals"""
if isinstance(item, QueueTerminator):
self.upstream_complete.set()
else:
self.items_in_flight += 1
await super().put(item)
async def get(self):
"""Track when items are consumed"""
item = await super().get()
if not isinstance(item, QueueTerminator):
self.items_in_flight -= 1
return item
def should_stop_waiting_for_batch(self) -> bool:
"""Determine if we should stop waiting and process current batch"""
return (
self.upstream_complete.is_set()
and self.items_in_flight == 0
and self.empty()
)
class AsyncStandardPdfPipeline(AsyncPipeline):
"""Async pipeline implementation with cross-document batching using structured concurrency"""
"""
An async, graph-based pipeline for processing PDFs with cross-document batching.
"""
def __init__(self, pipeline_options: AsyncPdfPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: AsyncPdfPipelineOptions = pipeline_options
# Resource management
self.page_tracker = AsyncPageTracker(
keep_images=self._should_keep_images(),
keep_backend=self._should_keep_backend(),
)
# Initialize models (same as StandardPdfPipeline)
self._initialize_models()
def _should_keep_images(self) -> bool:
"""Determine if images should be kept (same logic as StandardPdfPipeline)"""
return (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
@@ -111,7 +59,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
)
def _should_keep_backend(self) -> bool:
"""Determine if backend should be kept"""
return (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
@@ -120,36 +67,26 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
)
def _initialize_models(self):
"""Initialize all models (matching StandardPdfPipeline)"""
artifacts_path = self._get_artifacts_path()
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
# Build pipeline stages
self.preprocessing_model = PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=self.pipeline_options.images_scale,
)
)
self.ocr_model = self._get_ocr_model(artifacts_path)
self.layout_model = LayoutModel(
artifacts_path=artifacts_path,
accelerator_options=self.pipeline_options.accelerator_options,
options=self.pipeline_options.layout_options,
)
self.table_model = TableStructureModel(
enabled=self.pipeline_options.do_table_structure,
artifacts_path=artifacts_path,
options=self.pipeline_options.table_structure_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
# Enrichment models
self.code_formula_model = CodeFormulaModel(
enabled=self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_formula_enrichment,
@@ -160,20 +97,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
),
accelerator_options=self.pipeline_options.accelerator_options,
)
self.picture_classifier = DocumentPictureClassifier(
enabled=self.pipeline_options.do_picture_classification,
artifacts_path=artifacts_path,
options=DocumentPictureClassifierOptions(),
accelerator_options=self.pipeline_options.accelerator_options,
)
self.picture_description_model = self._get_picture_description_model(
artifacts_path
)
def _get_artifacts_path(self) -> Optional[str]:
"""Get artifacts path (same as StandardPdfPipeline)"""
from pathlib import Path
artifacts_path = None
@@ -190,7 +124,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
return artifacts_path
def _get_ocr_model(self, artifacts_path: Optional[str] = None) -> BaseOcrModel:
"""Get OCR model (same as StandardPdfPipeline)"""
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
@@ -202,7 +135,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
)
def _get_picture_description_model(self, artifacts_path: Optional[str] = None):
"""Get picture description model (same as StandardPdfPipeline)"""
factory = get_picture_description_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
@@ -217,607 +149,116 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def execute_stream(
self, input_docs: AsyncIterable[InputDocument]
) -> AsyncIterable[ConversionResult]:
"""Main async processing with structured concurrency and proper exception handling"""
# Create queues for pipeline stages
page_queue = UpstreamAwareQueue(
maxsize=self.pipeline_options.extraction_queue_size
)
completed_queue = UpstreamAwareQueue()
completed_docs = UpstreamAwareQueue()
"""Main async processing driven by a pipeline graph."""
stages = [
SourceStage("source"),
ExtractionStage(
"extractor",
self.page_tracker,
self.pipeline_options.max_concurrent_extractions,
),
PageProcessorStage("preprocessor", self.preprocessing_model),
BatchProcessorStage(
"ocr",
self.ocr_model,
self.pipeline_options.ocr_batch_size,
self.pipeline_options.batch_timeout_seconds,
),
BatchProcessorStage(
"layout",
self.layout_model,
self.pipeline_options.layout_batch_size,
self.pipeline_options.batch_timeout_seconds,
),
BatchProcessorStage(
"table",
self.table_model,
self.pipeline_options.table_batch_size,
self.pipeline_options.batch_timeout_seconds,
),
PageProcessorStage("assembler", self.assemble_model),
AggregationStage("aggregator", self.page_tracker, self._finalize_document),
SinkStage("sink"),
]
# Track active documents for proper termination
doc_tracker = {"active_docs": 0, "extraction_done": False}
doc_lock = asyncio.Lock()
edges = [
# Main processing path
{
"from_stage": "source",
"from_output": "out",
"to_stage": "extractor",
"to_input": "in",
},
{
"from_stage": "extractor",
"from_output": "out",
"to_stage": "preprocessor",
"to_input": "in",
},
{
"from_stage": "preprocessor",
"from_output": "out",
"to_stage": "ocr",
"to_input": "in",
},
{
"from_stage": "ocr",
"from_output": "out",
"to_stage": "layout",
"to_input": "in",
},
{
"from_stage": "layout",
"from_output": "out",
"to_stage": "table",
"to_input": "in",
},
{
"from_stage": "table",
"from_output": "out",
"to_stage": "assembler",
"to_input": "in",
},
{
"from_stage": "assembler",
"from_output": "out",
"to_stage": "aggregator",
"to_input": "in",
},
# Failure path
{
"from_stage": "extractor",
"from_output": "fail",
"to_stage": "aggregator",
"to_input": "fail",
},
# Final output
{
"from_stage": "aggregator",
"from_output": "out",
"to_stage": "sink",
"to_input": "in",
},
]
# Create exception event for coordinated shutdown
exception_event = asyncio.Event()
async def track_document_start():
async with doc_lock:
doc_tracker["active_docs"] += 1
async def track_document_complete():
async with doc_lock:
doc_tracker["active_docs"] -= 1
if doc_tracker["extraction_done"] and doc_tracker["active_docs"] == 0:
# All documents completed
await completed_docs.put(None)
runner = GraphRunner(stages, edges)
source_config = {"stage": "source", "channel": "out"}
sink_config = {"stage": "sink", "channel": "in"}
try:
async with asyncio.TaskGroup() as tg:
# Start all tasks
tg.create_task(
self._extract_documents_wrapper(
input_docs,
page_queue,
track_document_start,
exception_event,
doc_tracker,
doc_lock,
)
)
tg.create_task(
self._process_pages_wrapper(
page_queue, completed_queue, exception_event
)
)
tg.create_task(
self._aggregate_results_wrapper(
completed_queue,
completed_docs,
track_document_complete,
exception_event,
)
)
# Yield results as they complete
async for result in self._yield_results(
completed_docs, exception_event
):
yield result
async for result in runner.run(
input_docs,
source_config,
sink_config,
self.pipeline_options.extraction_queue_size,
):
yield result
except* Exception as eg:
# Handle exception group from TaskGroup
_log.error(f"Pipeline failed with exceptions: {eg.exceptions}")
# Re-raise the first exception
raise (eg.exceptions[0] if eg.exceptions else RuntimeError("Unknown error"))
finally:
# Ensure cleanup
await self.page_tracker.cleanup_all()
async def _extract_documents_wrapper(
self,
input_docs: AsyncIterable[InputDocument],
page_queue: UpstreamAwareQueue,
track_document_start,
exception_event: asyncio.Event,
doc_tracker: Dict[str, Any],
doc_lock: asyncio.Lock,
):
"""Wrapper for document extraction with exception handling"""
try:
await self._extract_documents_safe(
input_docs,
page_queue,
track_document_start,
exception_event,
)
except Exception:
exception_event.set()
raise
finally:
async with doc_lock:
doc_tracker["extraction_done"] = True
# Send termination signal
await page_queue.put(QueueTerminator("extraction"))
async def _process_pages_wrapper(
self,
page_queue: UpstreamAwareQueue,
completed_queue: UpstreamAwareQueue,
exception_event: asyncio.Event,
):
"""Wrapper for page processing with exception handling"""
try:
await self._process_pages_safe(page_queue, completed_queue, exception_event)
except Exception:
exception_event.set()
raise
finally:
# Send termination signal
await completed_queue.put(QueueTerminator("processing"))
async def _aggregate_results_wrapper(
self,
completed_queue: UpstreamAwareQueue,
completed_docs: UpstreamAwareQueue,
track_document_complete,
exception_event: asyncio.Event,
):
"""Wrapper for result aggregation with exception handling"""
try:
await self._aggregate_results_safe(
completed_queue,
completed_docs,
track_document_complete,
exception_event,
)
except Exception:
exception_event.set()
raise
async def _yield_results(
self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
):
"""Yield results as they complete"""
while True:
if exception_event.is_set():
break
try:
result = await asyncio.wait_for(completed_docs.get(), timeout=1.0)
if result is None:
break
yield result
except asyncio.TimeoutError:
continue
except Exception:
exception_event.set()
raise
async def _extract_documents_safe(
self,
input_docs: AsyncIterable[InputDocument],
page_queue: UpstreamAwareQueue,
track_document_start,
exception_event: asyncio.Event,
) -> None:
"""Extract pages from documents with exception handling"""
async for in_doc in input_docs:
if exception_event.is_set():
break
await track_document_start()
conv_res = ConversionResult(input=in_doc)
# Validate backend
if not isinstance(conv_res.input._backend, PdfDocumentBackend):
conv_res.status = ConversionStatus.FAILURE
await page_queue.put((None, conv_res)) # Signal failed document
continue
try:
# Initialize document
total_pages = conv_res.input.page_count
await self.page_tracker.register_document(conv_res, total_pages)
# Extract pages with limited concurrency
semaphore = asyncio.Semaphore(
self.pipeline_options.max_concurrent_extractions
)
async def extract_page(page_no: int):
if exception_event.is_set():
return
async with semaphore:
# Create page
page = Page(page_no=page_no)
conv_res.pages.append(page)
# Initialize page backend
page._backend = await asyncio.to_thread(
conv_res.input._backend.load_page, page_no
)
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
await self.page_tracker.track_page_loaded(page, conv_res)
# Send to processing queue
await page_queue.put((page, conv_res))
# Extract all pages concurrently
async with asyncio.TaskGroup() as tg:
for i in range(total_pages):
if exception_event.is_set():
break
start_page, end_page = conv_res.input.limits.page_range
if (start_page - 1) <= i <= (end_page - 1):
tg.create_task(extract_page(i))
except Exception as e:
_log.error(f"Failed to extract document {in_doc.file.name}: {e}")
conv_res.status = ConversionStatus.FAILURE
# Signal document failure
await page_queue.put((None, conv_res))
raise
async def _process_pages_safe(
self,
page_queue: UpstreamAwareQueue,
completed_queue: UpstreamAwareQueue,
exception_event: asyncio.Event,
) -> None:
"""Process pages through model pipeline with proper termination"""
# Process batches through each model stage
preprocessing_queue = UpstreamAwareQueue()
ocr_queue = UpstreamAwareQueue()
layout_queue = UpstreamAwareQueue()
table_queue = UpstreamAwareQueue()
assemble_queue = UpstreamAwareQueue()
# Start processing stages using TaskGroup
async with asyncio.TaskGroup() as tg:
# Preprocessing stage
tg.create_task(
self._batch_process_stage_safe(
page_queue,
preprocessing_queue,
self._preprocess_batch,
1,
0, # No batching for preprocessing
"preprocessing",
exception_event,
)
)
# OCR stage
tg.create_task(
self._batch_process_stage_safe(
preprocessing_queue,
ocr_queue,
self._ocr_batch,
self.pipeline_options.ocr_batch_size,
self.pipeline_options.batch_timeout_seconds,
"ocr",
exception_event,
)
)
# Layout stage
tg.create_task(
self._batch_process_stage_safe(
ocr_queue,
layout_queue,
self._layout_batch,
self.pipeline_options.layout_batch_size,
self.pipeline_options.batch_timeout_seconds,
"layout",
exception_event,
)
)
# Table stage
tg.create_task(
self._batch_process_stage_safe(
layout_queue,
table_queue,
self._table_batch,
self.pipeline_options.table_batch_size,
self.pipeline_options.batch_timeout_seconds,
"table",
exception_event,
)
)
# Assembly stage
tg.create_task(
self._batch_process_stage_safe(
table_queue,
assemble_queue,
self._assemble_batch,
1,
0, # No batching for assembly
"assembly",
exception_event,
)
)
# Finalization stage
tg.create_task(
self._finalize_pages_safe(
assemble_queue, completed_queue, exception_event
)
)
async def _batch_process_stage_safe(
self,
input_queue: UpstreamAwareQueue,
output_queue: UpstreamAwareQueue,
process_func,
batch_size: int,
timeout: float,
stage_name: str,
exception_event: asyncio.Event,
) -> None:
"""Generic batch processing stage with proper termination handling"""
batch = PageBatch()
try:
while not exception_event.is_set():
# Collect batch
try:
# Get first item or wait for timeout
if not batch.pages:
item = await input_queue.get()
# Check for termination
if isinstance(item, QueueTerminator):
# Propagate termination signal
await output_queue.put(item)
break
# Handle failed document signal
if item[0] is None:
# Pass through failure signal
await output_queue.put(item)
continue
batch.pages.append(item[0])
batch.conv_results.append(item[1])
# Try to fill batch up to batch_size or until upstream exhausted
while len(batch.pages) < batch_size:
# Check if upstream is exhausted and we should process immediately
if input_queue.should_stop_waiting_for_batch():
break
# Calculate remaining time for batch timeout
remaining_time = timeout - (time.time() - batch.start_time)
# If upstream is complete, use minimal timeout to drain quickly
if input_queue.upstream_complete.is_set():
remaining_time = min(remaining_time, 0.05)
if remaining_time <= 0:
break
try:
item = await asyncio.wait_for(
input_queue.get(), timeout=remaining_time
)
# Check for termination
if isinstance(item, QueueTerminator):
# Put it back and process current batch
await input_queue.put(item)
break
# Handle failed document signal
if item[0] is None:
# Put it back and process current batch
await input_queue.put(item)
break
batch.pages.append(item[0])
batch.conv_results.append(item[1])
except asyncio.TimeoutError:
break
# Process batch
if batch.pages:
processed = await process_func(batch)
# Send results to output queue
for page, conv_res in processed:
await output_queue.put((page, conv_res))
# Clear batch
batch = PageBatch()
except Exception as e:
_log.error(f"Error in {stage_name} batch processing: {e}")
# Send failed items downstream
for page, conv_res in zip(batch.pages, batch.conv_results):
await output_queue.put((page, conv_res))
batch = PageBatch()
raise
except Exception as e:
# Set exception event and propagate termination
exception_event.set()
await output_queue.put(QueueTerminator(stage_name, error=e))
raise
async def _preprocess_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Preprocess pages (no actual batching needed)"""
results = []
for page, conv_res in zip(batch.pages, batch.conv_results):
processed_page = await asyncio.to_thread(
lambda: next(iter(self.preprocessing_model(conv_res, [page])))
)
results.append((processed_page, conv_res))
return results
async def _ocr_batch(self, batch: PageBatch) -> List[Tuple[Page, ConversionResult]]:
"""Process OCR in batch"""
# Group by conversion result for proper context
grouped = defaultdict(list)
for page, conv_res in zip(batch.pages, batch.conv_results):
grouped[id(conv_res)].append(page)
results = []
for conv_res_id, pages in grouped.items():
# Find the conv_res
conv_res = next(
cr
for p, cr in zip(batch.pages, batch.conv_results)
if id(cr) == conv_res_id
)
# Process batch through OCR model
processed_pages = await asyncio.to_thread(
lambda: list(self.ocr_model(conv_res, pages))
)
for page in processed_pages:
results.append((page, conv_res))
return results
async def _layout_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Process layout in batch"""
# Similar batching as OCR
grouped = defaultdict(list)
for page, conv_res in zip(batch.pages, batch.conv_results):
grouped[id(conv_res)].append(page)
results = []
for conv_res_id, pages in grouped.items():
conv_res = next(
cr
for p, cr in zip(batch.pages, batch.conv_results)
if id(cr) == conv_res_id
)
processed_pages = await asyncio.to_thread(
lambda: list(self.layout_model(conv_res, pages))
)
for page in processed_pages:
results.append((page, conv_res))
return results
async def _table_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Process tables in batch"""
grouped = defaultdict(list)
for page, conv_res in zip(batch.pages, batch.conv_results):
grouped[id(conv_res)].append(page)
results = []
for conv_res_id, pages in grouped.items():
conv_res = next(
cr
for p, cr in zip(batch.pages, batch.conv_results)
if id(cr) == conv_res_id
)
processed_pages = await asyncio.to_thread(
lambda: list(self.table_model(conv_res, pages))
)
for page in processed_pages:
results.append((page, conv_res))
return results
async def _assemble_batch(
self, batch: PageBatch
) -> List[Tuple[Page, ConversionResult]]:
"""Assemble pages (no actual batching needed)"""
results = []
for page, conv_res in zip(batch.pages, batch.conv_results):
assembled_page = await asyncio.to_thread(
lambda: next(iter(self.assemble_model(conv_res, [page])))
)
results.append((assembled_page, conv_res))
return results
async def _finalize_pages_safe(
self,
input_queue: UpstreamAwareQueue,
output_queue: UpstreamAwareQueue,
exception_event: asyncio.Event,
) -> None:
"""Finalize pages and track completion with proper termination"""
try:
while not exception_event.is_set():
item = await input_queue.get()
# Check for termination
if isinstance(item, QueueTerminator):
# Propagate termination signal
await output_queue.put(item)
break
# Handle failed document signal
if item[0] is None:
# Pass through failure signal
await output_queue.put(item)
continue
page, conv_res = item
# Track page completion for resource cleanup
await self.page_tracker.track_page_completion(page, conv_res)
# Send to output
await output_queue.put((page, conv_res))
except Exception as e:
exception_event.set()
await output_queue.put(QueueTerminator("finalization", error=e))
raise
async def _aggregate_results_safe(
self,
completed_queue: UpstreamAwareQueue,
completed_docs: UpstreamAwareQueue,
track_document_complete,
exception_event: asyncio.Event,
) -> None:
"""Aggregate completed pages into documents with proper termination"""
doc_pages = defaultdict(list)
failed_docs = set()
try:
while not exception_event.is_set():
item = await completed_queue.get()
# Check for termination
if isinstance(item, QueueTerminator):
# Finalize any remaining documents
for conv_res_id, pages in doc_pages.items():
if conv_res_id not in failed_docs:
# Find conv_res from first page
conv_res = pages[0][1]
await self._finalize_document(conv_res)
await completed_docs.put(conv_res)
await track_document_complete()
break
# Handle failed document signal
if item[0] is None:
conv_res = item[1]
doc_id = id(conv_res)
failed_docs.add(doc_id)
# Send failed document immediately
await completed_docs.put(conv_res)
await track_document_complete()
continue
page, conv_res = item
doc_id = id(conv_res)
if doc_id not in failed_docs:
doc_pages[doc_id].append((page, conv_res))
# Check if document is complete
if len(doc_pages[doc_id]) == len(conv_res.pages):
await self._finalize_document(conv_res)
await completed_docs.put(conv_res)
await track_document_complete()
del doc_pages[doc_id]
except Exception:
exception_event.set()
# Try to send any completed documents before failing
for conv_res_id, pages in doc_pages.items():
if conv_res_id not in failed_docs and pages:
conv_res = pages[0][1]
conv_res.status = ConversionStatus.PARTIAL_SUCCESS
await completed_docs.put(conv_res)
await track_document_complete()
raise
async def _finalize_document(self, conv_res: ConversionResult) -> None:
"""Finalize a complete document (same as StandardPdfPipeline._assemble_document)"""
# This matches the logic from StandardPdfPipeline

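One possible way to drive the refactored pipeline end to end is sketched below. This is illustrative rather than code from the commit: the module path docling.pipeline.async_standard_pdf_pipeline and the default-constructed AsyncPdfPipelineOptions are assumptions, and producing the AsyncIterable[InputDocument] is left to the caller.

from typing import AsyncIterable, List

from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions
# assumed module path for the AsyncStandardPdfPipeline class shown above
from docling.pipeline.async_standard_pdf_pipeline import AsyncStandardPdfPipeline


async def convert_all(input_docs: AsyncIterable[InputDocument]) -> List[ConversionResult]:
    # Default options; real callers would tune batch sizes, timeouts and queue limits.
    pipeline = AsyncStandardPdfPipeline(AsyncPdfPipelineOptions())
    results: List[ConversionResult] = []
    async for conv_res in pipeline.execute_stream(input_docs):
        # Documents are yielded as soon as they finish, independent of input order.
        results.append(conv_res)
    return results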
docling/pipeline/graph.py Normal file

@@ -0,0 +1,144 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, AsyncIterable, Dict, List, Literal, Optional
# Sentinel to signal stream completion
STOP_SENTINEL = object()
@dataclass(slots=True)
class StreamItem:
"""
A wrapper for data flowing through the pipeline, maintaining a link
to the original conversion result context.
"""
payload: Any
conv_res_id: int
conv_res: Any # Opaque reference to ConversionResult
class PipelineStage(ABC):
"""A single, encapsulated step in a processing pipeline graph."""
def __init__(self, name: str):
self.name = name
self.input_queues: Dict[str, asyncio.Queue] = {}
self.output_queues: Dict[str, List[asyncio.Queue]] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
@abstractmethod
async def run(self) -> None:
"""
The core execution logic for the stage. This method is responsible for
consuming from input queues, processing data, and putting results into
output queues.
"""
async def _send_to_outputs(self, channel: str, items: List[StreamItem] | List[Any]):
"""Helper to send processed items to all connected output queues."""
if channel in self.output_queues:
for queue in self.output_queues[channel]:
for item in items:
await queue.put(item)
async def _signal_downstream_completion(self):
"""Signal that this stage is done processing to all output channels."""
for channel_queues in self.output_queues.values():
for queue in channel_queues:
await queue.put(STOP_SENTINEL)
@property
def loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop
class GraphRunner:
"""Connects stages and runs the pipeline graph."""
def __init__(self, stages: List[PipelineStage], edges: List[Dict[str, str]]):
self._stages = {s.name: s for s in stages}
self._edges = edges
def _wire_graph(self, queue_max_size: int):
"""Create queues for edges and connect them to stage inputs and outputs."""
for edge in self._edges:
from_stage, from_output = edge["from_stage"], edge["from_output"]
to_stage, to_input = edge["to_stage"], edge["to_input"]
queue = asyncio.Queue(maxsize=queue_max_size)
# Connect to source stage's output
self._stages[from_stage].output_queues.setdefault(from_output, []).append(
queue
)
# Connect to destination stage's input
self._stages[to_stage].input_queues[to_input] = queue
async def _run_source(
self,
source_stream: AsyncIterable[Any],
source_stage: str,
source_channel: str,
):
"""Feed the graph from an external async iterable."""
output_queues = self._stages[source_stage].output_queues.get(source_channel, [])
async for item in source_stream:
for queue in output_queues:
await queue.put(item)
# Signal completion to all downstream queues
for queue in output_queues:
await queue.put(STOP_SENTINEL)
async def _run_sink(self, sink_stage: str, sink_channel: str) -> AsyncIterable[Any]:
"""Yield results from the graph's final output queue."""
queue = self._stages[sink_stage].input_queues[sink_channel]
while True:
item = await queue.get()
if item is STOP_SENTINEL:
break
yield item
await queue.put(STOP_SENTINEL) # Allow other sinks to terminate
async def run(
self,
source_stream: AsyncIterable,
source_config: Dict[str, str],
sink_config: Dict[str, str],
queue_max_size: int = 32,
) -> AsyncIterable:
"""
Executes the entire pipeline graph.
Args:
source_stream: The initial async iterable to feed the graph.
source_config: Dictionary with "stage" and "channel" for the entry point.
sink_config: Dictionary with "stage" and "channel" for the exit point.
queue_max_size: The max size for the internal asyncio.Queues.
"""
self._wire_graph(queue_max_size)
async with asyncio.TaskGroup() as tg:
# Create a task for the source feeder
tg.create_task(
self._run_source(
source_stream, source_config["stage"], source_config["channel"]
)
)
# Create tasks for all pipeline stages
for stage in self._stages.values():
tg.create_task(stage.run())
# Yield results from the sink
async for result in self._run_sink(
sink_config["stage"], sink_config["channel"]
):
yield result

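A minimal usage sketch of the GraphRunner API defined above, assuming this commit's docling.pipeline.graph module is importable: a trivial custom stage is wired between placeholder source and sink stages, mirroring how SourceStage and SinkStage are used in stages.py.

import asyncio

from docling.pipeline.graph import STOP_SENTINEL, GraphRunner, PipelineStage


class PassThrough(PipelineStage):
    """Placeholder entry/exit stage; the runner drives its queues directly."""

    async def run(self) -> None:
        return


class UppercaseStage(PipelineStage):
    """Toy worker stage: uppercases strings from "in" and forwards them to "out"."""

    async def run(self) -> None:
        q_in = self.input_queues["in"]
        while True:
            item = await q_in.get()
            if item is STOP_SENTINEL:
                await self._signal_downstream_completion()
                break
            await self._send_to_outputs("out", [item.upper()])


async def main() -> None:
    stages = [PassThrough("source"), UppercaseStage("upper"), PassThrough("sink")]
    edges = [
        {"from_stage": "source", "from_output": "out", "to_stage": "upper", "to_input": "in"},
        {"from_stage": "upper", "from_output": "out", "to_stage": "sink", "to_input": "in"},
    ]

    async def words():
        for word in ("graph", "runner"):
            yield word

    runner = GraphRunner(stages, edges)
    async for result in runner.run(
        words(),
        {"stage": "source", "channel": "out"},
        {"stage": "sink", "channel": "in"},
    ):
        print(result)  # GRAPH, RUNNER


asyncio.run(main())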
docling/pipeline/resource_manager.py

@@ -0,0 +1,125 @@
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Set
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.pipeline.async_base_pipeline import DocumentTracker
_log = logging.getLogger(__name__)
@dataclass
class AsyncPageTracker:
"""Manages page backend lifecycle across documents"""
_doc_trackers: Dict[str, DocumentTracker] = field(default_factory=dict)
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
keep_images: bool = False
keep_backend: bool = False
async def register_document(
self, conv_res: ConversionResult, total_pages: int
) -> str:
"""Register a new document for tracking"""
async with self._lock:
doc_id = str(id(conv_res))
self._doc_trackers[doc_id] = DocumentTracker(
doc_id=doc_id, total_pages=total_pages, conv_result=conv_res
)
return doc_id
async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None:
"""Track when a page backend is loaded"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id in self._doc_trackers and page._backend is not None:
self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend
async def track_page_completion(
self, page: Page, conv_res: ConversionResult
) -> bool:
"""Track page completion and cleanup when all pages done"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id not in self._doc_trackers:
_log.warning(f"Document {doc_id} not registered for tracking")
return False
tracker = self._doc_trackers[doc_id]
tracker.processed_pages += 1
# Clear this page's image cache if needed
if not self.keep_images:
page._image_cache = {}
# If all pages from this document are processed, cleanup
if tracker.processed_pages == tracker.total_pages:
await self._cleanup_document_resources(tracker)
del self._doc_trackers[doc_id]
return True # Document is complete
return False # Document is not yet complete
async def _cleanup_document_resources(self, tracker: DocumentTracker) -> None:
"""Cleanup all resources for a completed document"""
if not self.keep_backend:
# Unload all page backends for this document
for page_no, backend in tracker.page_backends.items():
if backend is not None:
try:
# Run unload in thread to avoid blocking
await asyncio.to_thread(backend.unload)
except Exception as e:
_log.warning(
f"Failed to unload backend for page {page_no}: {e}"
)
tracker.page_backends.clear()
_log.debug(f"Cleaned up resources for document {tracker.doc_id}")
async def cleanup_all(self) -> None:
"""Cleanup all tracked documents - for shutdown"""
async with self._lock:
for tracker in self._doc_trackers.values():
await self._cleanup_document_resources(tracker)
self._doc_trackers.clear()
@dataclass
class ConversionResultAccumulator:
"""Accumulates updates to ConversionResult without immediate mutation"""
_updates: Dict[str, Dict] = field(default_factory=dict) # doc_id -> updates
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
async def accumulate_page_result(
self, page_no: int, conv_res: ConversionResult, updates: Dict
) -> None:
"""Accumulate updates for later application"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id not in self._updates:
self._updates[doc_id] = {}
if page_no not in self._updates[doc_id]:
self._updates[doc_id][page_no] = {}
self._updates[doc_id][page_no].update(updates)
async def flush_to_conv_res(self, conv_res: ConversionResult) -> None:
"""Apply all accumulated updates atomically"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id in self._updates:
# Apply updates
for page_no, updates in self._updates[doc_id].items():
# Find the page and apply updates
for page in conv_res.pages:
if page.page_no == page_no:
for key, value in updates.items():
setattr(page, key, value)
break
del self._updates[doc_id]

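A sketch of the AsyncPageTracker lifecycle as the pipeline stages use it. The conv_res and pages arguments are assumed to come from an upstream extraction step; the demo_lifecycle helper is hypothetical and not part of the commit.

from docling.pipeline.resource_manager import AsyncPageTracker


async def demo_lifecycle(conv_res, pages) -> None:
    # keep_images / keep_backend mirror the flags the pipeline derives from its options
    tracker = AsyncPageTracker(keep_images=False, keep_backend=False)
    await tracker.register_document(conv_res, total_pages=len(pages))
    for page in pages:
        # the extraction stage calls this once page._backend has been loaded
        await tracker.track_page_loaded(page, conv_res)
    for page in pages:
        done = await tracker.track_page_completion(page, conv_res)
        if done:
            # last page of the document: backends are unloaded and tracking is dropped
            print("document complete")
    await tracker.cleanup_all()  # shutdown hook; a no-op once nothing is tracked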
docling/pipeline/stages.py Normal file

@@ -0,0 +1,273 @@
from __future__ import annotations
import asyncio
import time
from collections import defaultdict
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
from docling.datamodel.base_models import ConversionStatus
from docling.datamodel.document import ConversionResult, InputDocument, Page
from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem
class SourceStage(PipelineStage):
"""A placeholder stage to represent the entry point of the graph."""
async def run(self) -> None:
# This stage is driven by the GraphRunner's _run_source method
# and does not have its own execution loop.
pass
class SinkStage(PipelineStage):
"""A placeholder stage to represent the exit point of the graph."""
async def run(self) -> None:
# This stage is read by the GraphRunner's _run_sink method
# and does not have its own execution loop.
pass
class ExtractionStage(PipelineStage):
"""Extracts pages from documents and tracks them."""
def __init__(
self,
name: str,
page_tracker: Any,
max_concurrent_extractions: int,
):
super().__init__(name)
self.page_tracker = page_tracker
self.semaphore = asyncio.Semaphore(max_concurrent_extractions)
self.input_channel = "in"
self.output_channel = "out"
self.failure_channel = "fail"
async def _extract_page(
self, page_no: int, conv_res: ConversionResult
) -> StreamItem | None:
"""Coroutine to extract a single page."""
try:
async with self.semaphore:
page = Page(page_no=page_no)
conv_res.pages.append(page)
page._backend = await self.loop.run_in_executor(
None, conv_res.input._backend.load_page, page_no
)
if page._backend and page._backend.is_valid():
page.size = page._backend.get_size()
await self.page_tracker.track_page_loaded(page, conv_res)
return StreamItem(
payload=page, conv_res_id=id(conv_res), conv_res=conv_res
)
except Exception:
# Page-level errors are swallowed here; failures are surfaced at the
# document level by _process_document via the failure channel.
pass
return None
async def _process_document(self, in_doc: InputDocument):
"""Processes a single document, extracting all its pages."""
conv_res = ConversionResult(input=in_doc)
try:
from docling.backend.pdf_backend import PdfDocumentBackend
if not isinstance(in_doc._backend, PdfDocumentBackend):
raise TypeError("Backend is not a valid PdfDocumentBackend")
total_pages = in_doc.page_count
await self.page_tracker.register_document(conv_res, total_pages)
start_page, end_page = conv_res.input.limits.page_range
page_indices_to_extract = [
i for i in range(total_pages) if (start_page - 1) <= i <= (end_page - 1)
]
tasks = [
self.loop.create_task(self._extract_page(i, conv_res))
for i in page_indices_to_extract
]
pages_extracted = await asyncio.gather(*tasks)
await self._send_to_outputs(
self.output_channel, [p for p in pages_extracted if p]
)
except Exception:
conv_res.status = ConversionStatus.FAILURE
await self._send_to_outputs(self.failure_channel, [conv_res])
async def run(self) -> None:
"""Main loop to consume documents and launch extraction tasks."""
q_in = self.input_queues[self.input_channel]
while True:
doc = await q_in.get()
if doc is STOP_SENTINEL:
await self._signal_downstream_completion()
break
await self._process_document(doc)
class PageProcessorStage(PipelineStage):
"""Applies a synchronous, 1-to-1 processing function to each page."""
def __init__(self, name: str, model: Any):
super().__init__(name)
self.model = model
self.input_channel = "in"
self.output_channel = "out"
async def run(self) -> None:
q_in = self.input_queues[self.input_channel]
while True:
item = await q_in.get()
if item is STOP_SENTINEL:
await self._signal_downstream_completion()
break
# The model call is sync, run in thread to avoid blocking event loop
processed_page = await self.loop.run_in_executor(
None, lambda: next(iter(self.model(item.conv_res, [item.payload])))
)
item.payload = processed_page
await self._send_to_outputs(self.output_channel, [item])
class BatchProcessorStage(PipelineStage):
"""Batches items and applies a synchronous model to the batch."""
def __init__(
self,
name: str,
model: Any,
batch_size: int,
batch_timeout: float,
):
super().__init__(name)
self.model = model
self.batch_size = batch_size
self.batch_timeout = batch_timeout
self.input_channel = "in"
self.output_channel = "out"
async def _collect_batch(self, q_in: asyncio.Queue) -> List[StreamItem] | None:
"""Collects a batch of items from the input queue with a timeout."""
try:
# Wait for the first item without a timeout
first_item = await q_in.get()
if first_item is STOP_SENTINEL:
return None # End of stream
except asyncio.CancelledError:
return None
batch = [first_item]
start_time = self.loop.time()
while len(batch) < self.batch_size:
timeout = self.batch_timeout - (self.loop.time() - start_time)
if timeout <= 0:
break
try:
item = await asyncio.wait_for(q_in.get(), timeout)
if item is STOP_SENTINEL:
# Put sentinel back for other potential consumers or the main loop
await q_in.put(STOP_SENTINEL)
break
batch.append(item)
except asyncio.TimeoutError:
break # Batching timeout reached
return batch
async def run(self) -> None:
q_in = self.input_queues[self.input_channel]
while True:
batch = await self._collect_batch(q_in)
if not batch:  # None signals end of stream
await self._signal_downstream_completion()
break
# Group pages by their original ConversionResult
grouped_by_doc = defaultdict(list)
for item in batch:
grouped_by_doc[item.conv_res_id].append(item)
processed_items = []
for conv_res_id, items in grouped_by_doc.items():
conv_res = items[0].conv_res
pages = [item.payload for item in items]
# The model call is sync, run in thread
processed_pages = await self.loop.run_in_executor(
None, lambda: list(self.model(conv_res, pages))
)
# Re-wrap the processed pages into StreamItems
for i, page in enumerate(processed_pages):
processed_items.append(
StreamItem(
payload=page,
conv_res_id=items[i].conv_res_id,
conv_res=items[i].conv_res,
)
)
await self._send_to_outputs(self.output_channel, processed_items)
class AggregationStage(PipelineStage):
"""Aggregates processed pages back into completed documents."""
def __init__(
self,
name: str,
page_tracker: Any,
finalizer_func: Callable[[ConversionResult], Coroutine],
):
super().__init__(name)
self.page_tracker = page_tracker
self.finalizer_func = finalizer_func
self.success_channel = "in"
self.failure_channel = "fail"
self.output_channel = "out"
async def run(self) -> None:
success_q = self.input_queues[self.success_channel]
failure_q = self.input_queues.get(self.failure_channel)
doc_completers: Dict[int, asyncio.Future] = {}
async def handle_successes():
while True:
item = await success_q.get()
if item is STOP_SENTINEL:
break
is_doc_complete = await self.page_tracker.track_page_completion(
item.payload, item.conv_res
)
if is_doc_complete:
await self.finalizer_func(item.conv_res)
await self._send_to_outputs(self.output_channel, [item.conv_res])
if item.conv_res_id in doc_completers:
doc_completers.pop(item.conv_res_id).set_result(True)
async def handle_failures():
if failure_q is None:
return # No failure channel, nothing to do
while True:
failed_res = await failure_q.get()
if failed_res is STOP_SENTINEL:
break
await self._send_to_outputs(self.output_channel, [failed_res])
if id(failed_res) in doc_completers:
doc_completers.pop(id(failed_res)).set_result(True)
# Create tasks only for channels that exist
tasks = [handle_successes()]
if failure_q is not None:
tasks.append(handle_failures())
await asyncio.gather(*tasks)
await self._signal_downstream_completion()
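
For reference, the only contract PageProcessorStage and BatchProcessorStage place on a model is that model(conv_res, pages) returns an iterable of processed pages. A hypothetical no-op model satisfying that contract, useful for wiring tests:

from typing import Any, Iterable, List


class NoOpModel:
    """Stand-in for an OCR/layout/table model under the stage contract above:
    calling model(conv_res, pages) must return an iterable of processed pages."""

    def __call__(self, conv_res: Any, pages: List[Any]) -> Iterable[Any]:
        for page in pages:
            # a real model would enrich the page here (OCR cells, layout clusters, ...)
            yield page


# e.g. BatchProcessorStage("noop", NoOpModel(), batch_size=4, batch_timeout=0.1)
# behaves as a pass-through stage under this contract.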