diff --git a/docling/pipeline/async_standard_pdf_pipeline.py b/docling/pipeline/async_standard_pdf_pipeline.py index e5a1bc5b..47c25a47 100644 --- a/docling/pipeline/async_standard_pdf_pipeline.py +++ b/docling/pipeline/async_standard_pdf_pipeline.py @@ -1,11 +1,7 @@ import asyncio import logging -import time -from collections import defaultdict -from dataclasses import dataclass, field from typing import Any, AsyncIterable, Dict, List, Optional, Tuple -from docling.backend.pdf_backend import PdfDocumentBackend from docling.datamodel.base_models import ConversionStatus, Page from docling.datamodel.document import ConversionResult, InputDocument from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions @@ -17,8 +13,6 @@ from docling.models.document_picture_classifier import ( DocumentPictureClassifierOptions, ) from docling.models.factories import get_ocr_factory, get_picture_description_factory - -# Import the same models used by StandardPdfPipeline from docling.models.layout_model import LayoutModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_preprocessing_model import ( @@ -28,82 +22,36 @@ from docling.models.page_preprocessing_model import ( from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions from docling.models.table_structure_model import TableStructureModel from docling.pipeline.async_base_pipeline import AsyncPipeline -from docling.pipeline.resource_manager import ( - AsyncPageTracker, - ConversionResultAccumulator, +from docling.pipeline.graph import GraphRunner +from docling.pipeline.resource_manager import AsyncPageTracker +from docling.pipeline.stages import ( + AggregationStage, + BatchProcessorStage, + ExtractionStage, + PageProcessorStage, + SinkStage, + SourceStage, ) from docling.utils.profiling import ProfilingScope, TimeRecorder _log = logging.getLogger(__name__) -@dataclass -class PageBatch: - """Represents a batch of pages to process through models""" - - pages: List[Page] = field(default_factory=list) - conv_results: List[ConversionResult] = field(default_factory=list) - start_time: float = field(default_factory=time.time) - - -@dataclass -class QueueTerminator: - """Sentinel value for proper queue termination tracking""" - - stage: str - error: Optional[Exception] = None - - -class UpstreamAwareQueue(asyncio.Queue): - """Queue that tracks upstream completion to optimize batch processing""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.upstream_complete = asyncio.Event() - self.items_in_flight = 0 - - async def put(self, item): - """Track items and upstream completion signals""" - if isinstance(item, QueueTerminator): - self.upstream_complete.set() - else: - self.items_in_flight += 1 - await super().put(item) - - async def get(self): - """Track when items are consumed""" - item = await super().get() - if not isinstance(item, QueueTerminator): - self.items_in_flight -= 1 - return item - - def should_stop_waiting_for_batch(self) -> bool: - """Determine if we should stop waiting and process current batch""" - return ( - self.upstream_complete.is_set() - and self.items_in_flight == 0 - and self.empty() - ) - - class AsyncStandardPdfPipeline(AsyncPipeline): - """Async pipeline implementation with cross-document batching using structured concurrency""" + """ + An async, graph-based pipeline for processing PDFs with cross-document batching. 
+ """ def __init__(self, pipeline_options: AsyncPdfPipelineOptions): super().__init__(pipeline_options) self.pipeline_options: AsyncPdfPipelineOptions = pipeline_options - - # Resource management self.page_tracker = AsyncPageTracker( keep_images=self._should_keep_images(), keep_backend=self._should_keep_backend(), ) - - # Initialize models (same as StandardPdfPipeline) self._initialize_models() def _should_keep_images(self) -> bool: - """Determine if images should be kept (same logic as StandardPdfPipeline)""" return ( self.pipeline_options.generate_page_images or self.pipeline_options.generate_picture_images @@ -111,7 +59,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline): ) def _should_keep_backend(self) -> bool: - """Determine if backend should be kept""" return ( self.pipeline_options.do_formula_enrichment or self.pipeline_options.do_code_enrichment @@ -120,36 +67,26 @@ class AsyncStandardPdfPipeline(AsyncPipeline): ) def _initialize_models(self): - """Initialize all models (matching StandardPdfPipeline)""" artifacts_path = self._get_artifacts_path() - self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions()) - - # Build pipeline stages self.preprocessing_model = PagePreprocessingModel( options=PagePreprocessingOptions( images_scale=self.pipeline_options.images_scale, ) ) - self.ocr_model = self._get_ocr_model(artifacts_path) - self.layout_model = LayoutModel( artifacts_path=artifacts_path, accelerator_options=self.pipeline_options.accelerator_options, options=self.pipeline_options.layout_options, ) - self.table_model = TableStructureModel( enabled=self.pipeline_options.do_table_structure, artifacts_path=artifacts_path, options=self.pipeline_options.table_structure_options, accelerator_options=self.pipeline_options.accelerator_options, ) - self.assemble_model = PageAssembleModel(options=PageAssembleOptions()) - - # Enrichment models self.code_formula_model = CodeFormulaModel( enabled=self.pipeline_options.do_code_enrichment or self.pipeline_options.do_formula_enrichment, @@ -160,20 +97,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline): ), accelerator_options=self.pipeline_options.accelerator_options, ) - self.picture_classifier = DocumentPictureClassifier( enabled=self.pipeline_options.do_picture_classification, artifacts_path=artifacts_path, options=DocumentPictureClassifierOptions(), accelerator_options=self.pipeline_options.accelerator_options, ) - self.picture_description_model = self._get_picture_description_model( artifacts_path ) def _get_artifacts_path(self) -> Optional[str]: - """Get artifacts path (same as StandardPdfPipeline)""" from pathlib import Path artifacts_path = None @@ -190,7 +124,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline): return artifacts_path def _get_ocr_model(self, artifacts_path: Optional[str] = None) -> BaseOcrModel: - """Get OCR model (same as StandardPdfPipeline)""" factory = get_ocr_factory( allow_external_plugins=self.pipeline_options.allow_external_plugins ) @@ -202,7 +135,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline): ) def _get_picture_description_model(self, artifacts_path: Optional[str] = None): - """Get picture description model (same as StandardPdfPipeline)""" factory = get_picture_description_factory( allow_external_plugins=self.pipeline_options.allow_external_plugins ) @@ -217,607 +149,116 @@ class AsyncStandardPdfPipeline(AsyncPipeline): async def execute_stream( self, input_docs: AsyncIterable[InputDocument] ) -> AsyncIterable[ConversionResult]: - """Main async processing with structured concurrency and proper 
exception handling""" - # Create queues for pipeline stages - page_queue = UpstreamAwareQueue( - maxsize=self.pipeline_options.extraction_queue_size - ) - completed_queue = UpstreamAwareQueue() - completed_docs = UpstreamAwareQueue() + """Main async processing driven by a pipeline graph.""" + stages = [ + SourceStage("source"), + ExtractionStage( + "extractor", + self.page_tracker, + self.pipeline_options.max_concurrent_extractions, + ), + PageProcessorStage("preprocessor", self.preprocessing_model), + BatchProcessorStage( + "ocr", + self.ocr_model, + self.pipeline_options.ocr_batch_size, + self.pipeline_options.batch_timeout_seconds, + ), + BatchProcessorStage( + "layout", + self.layout_model, + self.pipeline_options.layout_batch_size, + self.pipeline_options.batch_timeout_seconds, + ), + BatchProcessorStage( + "table", + self.table_model, + self.pipeline_options.table_batch_size, + self.pipeline_options.batch_timeout_seconds, + ), + PageProcessorStage("assembler", self.assemble_model), + AggregationStage("aggregator", self.page_tracker, self._finalize_document), + SinkStage("sink"), + ] - # Track active documents for proper termination - doc_tracker = {"active_docs": 0, "extraction_done": False} - doc_lock = asyncio.Lock() + edges = [ + # Main processing path + { + "from_stage": "source", + "from_output": "out", + "to_stage": "extractor", + "to_input": "in", + }, + { + "from_stage": "extractor", + "from_output": "out", + "to_stage": "preprocessor", + "to_input": "in", + }, + { + "from_stage": "preprocessor", + "from_output": "out", + "to_stage": "ocr", + "to_input": "in", + }, + { + "from_stage": "ocr", + "from_output": "out", + "to_stage": "layout", + "to_input": "in", + }, + { + "from_stage": "layout", + "from_output": "out", + "to_stage": "table", + "to_input": "in", + }, + { + "from_stage": "table", + "from_output": "out", + "to_stage": "assembler", + "to_input": "in", + }, + { + "from_stage": "assembler", + "from_output": "out", + "to_stage": "aggregator", + "to_input": "in", + }, + # Failure path + { + "from_stage": "extractor", + "from_output": "fail", + "to_stage": "aggregator", + "to_input": "fail", + }, + # Final output + { + "from_stage": "aggregator", + "from_output": "out", + "to_stage": "sink", + "to_input": "in", + }, + ] - # Create exception event for coordinated shutdown - exception_event = asyncio.Event() - - async def track_document_start(): - async with doc_lock: - doc_tracker["active_docs"] += 1 - - async def track_document_complete(): - async with doc_lock: - doc_tracker["active_docs"] -= 1 - if doc_tracker["extraction_done"] and doc_tracker["active_docs"] == 0: - # All documents completed - await completed_docs.put(None) + runner = GraphRunner(stages, edges) + source_config = {"stage": "source", "channel": "out"} + sink_config = {"stage": "sink", "channel": "in"} try: - async with asyncio.TaskGroup() as tg: - # Start all tasks - tg.create_task( - self._extract_documents_wrapper( - input_docs, - page_queue, - track_document_start, - exception_event, - doc_tracker, - doc_lock, - ) - ) - tg.create_task( - self._process_pages_wrapper( - page_queue, completed_queue, exception_event - ) - ) - tg.create_task( - self._aggregate_results_wrapper( - completed_queue, - completed_docs, - track_document_complete, - exception_event, - ) - ) - - # Yield results as they complete - async for result in self._yield_results( - completed_docs, exception_event - ): - yield result - + async for result in runner.run( + input_docs, + source_config, + sink_config, + 
self.pipeline_options.extraction_queue_size, + ): + yield result except* Exception as eg: - # Handle exception group from TaskGroup _log.error(f"Pipeline failed with exceptions: {eg.exceptions}") - # Re-raise the first exception raise (eg.exceptions[0] if eg.exceptions else RuntimeError("Unknown error")) finally: - # Ensure cleanup await self.page_tracker.cleanup_all() - async def _extract_documents_wrapper( - self, - input_docs: AsyncIterable[InputDocument], - page_queue: UpstreamAwareQueue, - track_document_start, - exception_event: asyncio.Event, - doc_tracker: Dict[str, Any], - doc_lock: asyncio.Lock, - ): - """Wrapper for document extraction with exception handling""" - try: - await self._extract_documents_safe( - input_docs, - page_queue, - track_document_start, - exception_event, - ) - except Exception: - exception_event.set() - raise - finally: - async with doc_lock: - doc_tracker["extraction_done"] = True - # Send termination signal - await page_queue.put(QueueTerminator("extraction")) - - async def _process_pages_wrapper( - self, - page_queue: UpstreamAwareQueue, - completed_queue: UpstreamAwareQueue, - exception_event: asyncio.Event, - ): - """Wrapper for page processing with exception handling""" - try: - await self._process_pages_safe(page_queue, completed_queue, exception_event) - except Exception: - exception_event.set() - raise - finally: - # Send termination signal - await completed_queue.put(QueueTerminator("processing")) - - async def _aggregate_results_wrapper( - self, - completed_queue: UpstreamAwareQueue, - completed_docs: UpstreamAwareQueue, - track_document_complete, - exception_event: asyncio.Event, - ): - """Wrapper for result aggregation with exception handling""" - try: - await self._aggregate_results_safe( - completed_queue, - completed_docs, - track_document_complete, - exception_event, - ) - except Exception: - exception_event.set() - raise - - async def _yield_results( - self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event - ): - """Yield results as they complete""" - while True: - if exception_event.is_set(): - break - - try: - result = await asyncio.wait_for(completed_docs.get(), timeout=1.0) - if result is None: - break - yield result - except asyncio.TimeoutError: - continue - except Exception: - exception_event.set() - raise - - async def _extract_documents_safe( - self, - input_docs: AsyncIterable[InputDocument], - page_queue: UpstreamAwareQueue, - track_document_start, - exception_event: asyncio.Event, - ) -> None: - """Extract pages from documents with exception handling""" - async for in_doc in input_docs: - if exception_event.is_set(): - break - - await track_document_start() - conv_res = ConversionResult(input=in_doc) - - # Validate backend - if not isinstance(conv_res.input._backend, PdfDocumentBackend): - conv_res.status = ConversionStatus.FAILURE - await page_queue.put((None, conv_res)) # Signal failed document - continue - - try: - # Initialize document - total_pages = conv_res.input.page_count - await self.page_tracker.register_document(conv_res, total_pages) - - # Extract pages with limited concurrency - semaphore = asyncio.Semaphore( - self.pipeline_options.max_concurrent_extractions - ) - - async def extract_page(page_no: int): - if exception_event.is_set(): - return - - async with semaphore: - # Create page - page = Page(page_no=page_no) - conv_res.pages.append(page) - - # Initialize page backend - page._backend = await asyncio.to_thread( - conv_res.input._backend.load_page, page_no - ) - - if page._backend is not None 
and page._backend.is_valid(): - page.size = page._backend.get_size() - await self.page_tracker.track_page_loaded(page, conv_res) - - # Send to processing queue - await page_queue.put((page, conv_res)) - - # Extract all pages concurrently - async with asyncio.TaskGroup() as tg: - for i in range(total_pages): - if exception_event.is_set(): - break - start_page, end_page = conv_res.input.limits.page_range - if (start_page - 1) <= i <= (end_page - 1): - tg.create_task(extract_page(i)) - - except Exception as e: - _log.error(f"Failed to extract document {in_doc.file.name}: {e}") - conv_res.status = ConversionStatus.FAILURE - # Signal document failure - await page_queue.put((None, conv_res)) - raise - - async def _process_pages_safe( - self, - page_queue: UpstreamAwareQueue, - completed_queue: UpstreamAwareQueue, - exception_event: asyncio.Event, - ) -> None: - """Process pages through model pipeline with proper termination""" - # Process batches through each model stage - preprocessing_queue = UpstreamAwareQueue() - ocr_queue = UpstreamAwareQueue() - layout_queue = UpstreamAwareQueue() - table_queue = UpstreamAwareQueue() - assemble_queue = UpstreamAwareQueue() - - # Start processing stages using TaskGroup - async with asyncio.TaskGroup() as tg: - # Preprocessing stage - tg.create_task( - self._batch_process_stage_safe( - page_queue, - preprocessing_queue, - self._preprocess_batch, - 1, - 0, # No batching for preprocessing - "preprocessing", - exception_event, - ) - ) - - # OCR stage - tg.create_task( - self._batch_process_stage_safe( - preprocessing_queue, - ocr_queue, - self._ocr_batch, - self.pipeline_options.ocr_batch_size, - self.pipeline_options.batch_timeout_seconds, - "ocr", - exception_event, - ) - ) - - # Layout stage - tg.create_task( - self._batch_process_stage_safe( - ocr_queue, - layout_queue, - self._layout_batch, - self.pipeline_options.layout_batch_size, - self.pipeline_options.batch_timeout_seconds, - "layout", - exception_event, - ) - ) - - # Table stage - tg.create_task( - self._batch_process_stage_safe( - layout_queue, - table_queue, - self._table_batch, - self.pipeline_options.table_batch_size, - self.pipeline_options.batch_timeout_seconds, - "table", - exception_event, - ) - ) - - # Assembly stage - tg.create_task( - self._batch_process_stage_safe( - table_queue, - assemble_queue, - self._assemble_batch, - 1, - 0, # No batching for assembly - "assembly", - exception_event, - ) - ) - - # Finalization stage - tg.create_task( - self._finalize_pages_safe( - assemble_queue, completed_queue, exception_event - ) - ) - - async def _batch_process_stage_safe( - self, - input_queue: UpstreamAwareQueue, - output_queue: UpstreamAwareQueue, - process_func, - batch_size: int, - timeout: float, - stage_name: str, - exception_event: asyncio.Event, - ) -> None: - """Generic batch processing stage with proper termination handling""" - batch = PageBatch() - - try: - while not exception_event.is_set(): - # Collect batch - try: - # Get first item or wait for timeout - if not batch.pages: - item = await input_queue.get() - - # Check for termination - if isinstance(item, QueueTerminator): - # Propagate termination signal - await output_queue.put(item) - break - - # Handle failed document signal - if item[0] is None: - # Pass through failure signal - await output_queue.put(item) - continue - - batch.pages.append(item[0]) - batch.conv_results.append(item[1]) - - # Try to fill batch up to batch_size or until upstream exhausted - while len(batch.pages) < batch_size: - # Check if upstream is 
exhausted and we should process immediately - if input_queue.should_stop_waiting_for_batch(): - break - - # Calculate remaining time for batch timeout - remaining_time = timeout - (time.time() - batch.start_time) - - # If upstream is complete, use minimal timeout to drain quickly - if input_queue.upstream_complete.is_set(): - remaining_time = min(remaining_time, 0.05) - - if remaining_time <= 0: - break - - try: - item = await asyncio.wait_for( - input_queue.get(), timeout=remaining_time - ) - - # Check for termination - if isinstance(item, QueueTerminator): - # Put it back and process current batch - await input_queue.put(item) - break - - # Handle failed document signal - if item[0] is None: - # Put it back and process current batch - await input_queue.put(item) - break - - batch.pages.append(item[0]) - batch.conv_results.append(item[1]) - except asyncio.TimeoutError: - break - - # Process batch - if batch.pages: - processed = await process_func(batch) - - # Send results to output queue - for page, conv_res in processed: - await output_queue.put((page, conv_res)) - - # Clear batch - batch = PageBatch() - - except Exception as e: - _log.error(f"Error in {stage_name} batch processing: {e}") - # Send failed items downstream - for page, conv_res in zip(batch.pages, batch.conv_results): - await output_queue.put((page, conv_res)) - batch = PageBatch() - raise - - except Exception as e: - # Set exception event and propagate termination - exception_event.set() - await output_queue.put(QueueTerminator(stage_name, error=e)) - raise - - async def _preprocess_batch( - self, batch: PageBatch - ) -> List[Tuple[Page, ConversionResult]]: - """Preprocess pages (no actual batching needed)""" - results = [] - for page, conv_res in zip(batch.pages, batch.conv_results): - processed_page = await asyncio.to_thread( - lambda: next(iter(self.preprocessing_model(conv_res, [page]))) - ) - results.append((processed_page, conv_res)) - return results - - async def _ocr_batch(self, batch: PageBatch) -> List[Tuple[Page, ConversionResult]]: - """Process OCR in batch""" - # Group by conversion result for proper context - grouped = defaultdict(list) - for page, conv_res in zip(batch.pages, batch.conv_results): - grouped[id(conv_res)].append(page) - - results = [] - for conv_res_id, pages in grouped.items(): - # Find the conv_res - conv_res = next( - cr - for p, cr in zip(batch.pages, batch.conv_results) - if id(cr) == conv_res_id - ) - - # Process batch through OCR model - processed_pages = await asyncio.to_thread( - lambda: list(self.ocr_model(conv_res, pages)) - ) - - for page in processed_pages: - results.append((page, conv_res)) - - return results - - async def _layout_batch( - self, batch: PageBatch - ) -> List[Tuple[Page, ConversionResult]]: - """Process layout in batch""" - # Similar batching as OCR - grouped = defaultdict(list) - for page, conv_res in zip(batch.pages, batch.conv_results): - grouped[id(conv_res)].append(page) - - results = [] - for conv_res_id, pages in grouped.items(): - conv_res = next( - cr - for p, cr in zip(batch.pages, batch.conv_results) - if id(cr) == conv_res_id - ) - - processed_pages = await asyncio.to_thread( - lambda: list(self.layout_model(conv_res, pages)) - ) - - for page in processed_pages: - results.append((page, conv_res)) - - return results - - async def _table_batch( - self, batch: PageBatch - ) -> List[Tuple[Page, ConversionResult]]: - """Process tables in batch""" - grouped = defaultdict(list) - for page, conv_res in zip(batch.pages, batch.conv_results): - 
grouped[id(conv_res)].append(page) - - results = [] - for conv_res_id, pages in grouped.items(): - conv_res = next( - cr - for p, cr in zip(batch.pages, batch.conv_results) - if id(cr) == conv_res_id - ) - - processed_pages = await asyncio.to_thread( - lambda: list(self.table_model(conv_res, pages)) - ) - - for page in processed_pages: - results.append((page, conv_res)) - - return results - - async def _assemble_batch( - self, batch: PageBatch - ) -> List[Tuple[Page, ConversionResult]]: - """Assemble pages (no actual batching needed)""" - results = [] - for page, conv_res in zip(batch.pages, batch.conv_results): - assembled_page = await asyncio.to_thread( - lambda: next(iter(self.assemble_model(conv_res, [page]))) - ) - results.append((assembled_page, conv_res)) - return results - - async def _finalize_pages_safe( - self, - input_queue: UpstreamAwareQueue, - output_queue: UpstreamAwareQueue, - exception_event: asyncio.Event, - ) -> None: - """Finalize pages and track completion with proper termination""" - try: - while not exception_event.is_set(): - item = await input_queue.get() - - # Check for termination - if isinstance(item, QueueTerminator): - # Propagate termination signal - await output_queue.put(item) - break - - # Handle failed document signal - if item[0] is None: - # Pass through failure signal - await output_queue.put(item) - continue - - page, conv_res = item - - # Track page completion for resource cleanup - await self.page_tracker.track_page_completion(page, conv_res) - - # Send to output - await output_queue.put((page, conv_res)) - - except Exception as e: - exception_event.set() - await output_queue.put(QueueTerminator("finalization", error=e)) - raise - - async def _aggregate_results_safe( - self, - completed_queue: UpstreamAwareQueue, - completed_docs: UpstreamAwareQueue, - track_document_complete, - exception_event: asyncio.Event, - ) -> None: - """Aggregate completed pages into documents with proper termination""" - doc_pages = defaultdict(list) - failed_docs = set() - - try: - while not exception_event.is_set(): - item = await completed_queue.get() - - # Check for termination - if isinstance(item, QueueTerminator): - # Finalize any remaining documents - for conv_res_id, pages in doc_pages.items(): - if conv_res_id not in failed_docs: - # Find conv_res from first page - conv_res = pages[0][1] - await self._finalize_document(conv_res) - await completed_docs.put(conv_res) - await track_document_complete() - break - - # Handle failed document signal - if item[0] is None: - conv_res = item[1] - doc_id = id(conv_res) - failed_docs.add(doc_id) - # Send failed document immediately - await completed_docs.put(conv_res) - await track_document_complete() - continue - - page, conv_res = item - doc_id = id(conv_res) - - if doc_id not in failed_docs: - doc_pages[doc_id].append((page, conv_res)) - - # Check if document is complete - if len(doc_pages[doc_id]) == len(conv_res.pages): - await self._finalize_document(conv_res) - await completed_docs.put(conv_res) - await track_document_complete() - del doc_pages[doc_id] - - except Exception: - exception_event.set() - # Try to send any completed documents before failing - for conv_res_id, pages in doc_pages.items(): - if conv_res_id not in failed_docs and pages: - conv_res = pages[0][1] - conv_res.status = ConversionStatus.PARTIAL_SUCCESS - await completed_docs.put(conv_res) - await track_document_complete() - raise - async def _finalize_document(self, conv_res: ConversionResult) -> None: """Finalize a complete document (same as 
StandardPdfPipeline._assemble_document)""" # This matches the logic from StandardPdfPipeline diff --git a/docling/pipeline/graph.py b/docling/pipeline/graph.py new file mode 100644 index 00000000..0d9827a1 --- /dev/null +++ b/docling/pipeline/graph.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, AsyncIterable, Dict, List, Literal, Optional + +# Sentinel to signal stream completion +STOP_SENTINEL = object() + + +@dataclass(slots=True) +class StreamItem: + """ + A wrapper for data flowing through the pipeline, maintaining a link + to the original conversion result context. + """ + + payload: Any + conv_res_id: int + conv_res: Any # Opaque reference to ConversionResult + + +class PipelineStage(ABC): + """A single, encapsulated step in a processing pipeline graph.""" + + def __init__(self, name: str): + self.name = name + self.input_queues: Dict[str, asyncio.Queue] = {} + self.output_queues: Dict[str, List[asyncio.Queue]] = {} + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @abstractmethod + async def run(self) -> None: + """ + The core execution logic for the stage. This method is responsible for + consuming from input queues, processing data, and putting results into + output queues. + """ + + async def _send_to_outputs(self, channel: str, items: List[StreamItem] | List[Any]): + """Helper to send processed items to all connected output queues.""" + if channel in self.output_queues: + for queue in self.output_queues[channel]: + for item in items: + await queue.put(item) + + async def _signal_downstream_completion(self): + """Signal that this stage is done processing to all output channels.""" + for channel_queues in self.output_queues.values(): + for queue in channel_queues: + await queue.put(STOP_SENTINEL) + + @property + def loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None: + self._loop = asyncio.get_running_loop() + return self._loop + + +class GraphRunner: + """Connects stages and runs the pipeline graph.""" + + def __init__(self, stages: List[PipelineStage], edges: List[Dict[str, str]]): + self._stages = {s.name: s for s in stages} + self._edges = edges + + def _wire_graph(self, queue_max_size: int): + """Create queues for edges and connect them to stage inputs and outputs.""" + for edge in self._edges: + from_stage, from_output = edge["from_stage"], edge["from_output"] + to_stage, to_input = edge["to_stage"], edge["to_input"] + + queue = asyncio.Queue(maxsize=queue_max_size) + + # Connect to source stage's output + self._stages[from_stage].output_queues.setdefault(from_output, []).append( + queue + ) + + # Connect to destination stage's input + self._stages[to_stage].input_queues[to_input] = queue + + async def _run_source( + self, + source_stream: AsyncIterable[Any], + source_stage: str, + source_channel: str, + ): + """Feed the graph from an external async iterable.""" + output_queues = self._stages[source_stage].output_queues.get(source_channel, []) + async for item in source_stream: + for queue in output_queues: + await queue.put(item) + # Signal completion to all downstream queues + for queue in output_queues: + await queue.put(STOP_SENTINEL) + + async def _run_sink(self, sink_stage: str, sink_channel: str) -> AsyncIterable[Any]: + """Yield results from the graph's final output queue.""" + queue = self._stages[sink_stage].input_queues[sink_channel] + while True: + item = await queue.get() + if 
item is STOP_SENTINEL: + break + yield item + await queue.put(STOP_SENTINEL) # Allow other sinks to terminate + + async def run( + self, + source_stream: AsyncIterable, + source_config: Dict[str, str], + sink_config: Dict[str, str], + queue_max_size: int = 32, + ) -> AsyncIterable: + """ + Executes the entire pipeline graph. + + Args: + source_stream: The initial async iterable to feed the graph. + source_config: Dictionary with "stage" and "channel" for the entry point. + sink_config: Dictionary with "stage" and "channel" for the exit point. + queue_max_size: The max size for the internal asyncio.Queues. + """ + self._wire_graph(queue_max_size) + + async with asyncio.TaskGroup() as tg: + # Create a task for the source feeder + tg.create_task( + self._run_source( + source_stream, source_config["stage"], source_config["channel"] + ) + ) + + # Create tasks for all pipeline stages + for stage in self._stages.values(): + tg.create_task(stage.run()) + + # Yield results from the sink + async for result in self._run_sink( + sink_config["stage"], sink_config["channel"] + ): + yield result diff --git a/docling/pipeline/resource_manager.py b/docling/pipeline/resource_manager.py new file mode 100644 index 00000000..4e1d5e2f --- /dev/null +++ b/docling/pipeline/resource_manager.py @@ -0,0 +1,125 @@ +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Dict, Optional, Set + +from docling.datamodel.base_models import Page +from docling.datamodel.document import ConversionResult +from docling.pipeline.async_base_pipeline import DocumentTracker + +_log = logging.getLogger(__name__) + + +@dataclass +class AsyncPageTracker: + """Manages page backend lifecycle across documents""" + + _doc_trackers: Dict[str, DocumentTracker] = field(default_factory=dict) + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) + keep_images: bool = False + keep_backend: bool = False + + async def register_document( + self, conv_res: ConversionResult, total_pages: int + ) -> str: + """Register a new document for tracking""" + async with self._lock: + doc_id = str(id(conv_res)) + self._doc_trackers[doc_id] = DocumentTracker( + doc_id=doc_id, total_pages=total_pages, conv_result=conv_res + ) + return doc_id + + async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None: + """Track when a page backend is loaded""" + async with self._lock: + doc_id = str(id(conv_res)) + if doc_id in self._doc_trackers and page._backend is not None: + self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend + + async def track_page_completion( + self, page: Page, conv_res: ConversionResult + ) -> bool: + """Track page completion and cleanup when all pages done""" + async with self._lock: + doc_id = str(id(conv_res)) + if doc_id not in self._doc_trackers: + _log.warning(f"Document {doc_id} not registered for tracking") + return False + + tracker = self._doc_trackers[doc_id] + tracker.processed_pages += 1 + + # Clear this page's image cache if needed + if not self.keep_images: + page._image_cache = {} + + # If all pages from this document are processed, cleanup + if tracker.processed_pages == tracker.total_pages: + await self._cleanup_document_resources(tracker) + del self._doc_trackers[doc_id] + return True # Document is complete + + return False # Document is not yet complete + + async def _cleanup_document_resources(self, tracker: DocumentTracker) -> None: + """Cleanup all resources for a completed document""" + if not self.keep_backend: + # Unload all page backends 
for this document + for page_no, backend in tracker.page_backends.items(): + if backend is not None: + try: + # Run unload in thread to avoid blocking + await asyncio.to_thread(backend.unload) + except Exception as e: + _log.warning( + f"Failed to unload backend for page {page_no}: {e}" + ) + + tracker.page_backends.clear() + _log.debug(f"Cleaned up resources for document {tracker.doc_id}") + + async def cleanup_all(self) -> None: + """Cleanup all tracked documents - for shutdown""" + async with self._lock: + for tracker in self._doc_trackers.values(): + await self._cleanup_document_resources(tracker) + self._doc_trackers.clear() + + +@dataclass +class ConversionResultAccumulator: + """Accumulates updates to ConversionResult without immediate mutation""" + + _updates: Dict[str, Dict] = field(default_factory=dict) # doc_id -> updates + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + async def accumulate_page_result( + self, page_no: int, conv_res: ConversionResult, updates: Dict + ) -> None: + """Accumulate updates for later application""" + async with self._lock: + doc_id = str(id(conv_res)) + if doc_id not in self._updates: + self._updates[doc_id] = {} + + if page_no not in self._updates[doc_id]: + self._updates[doc_id][page_no] = {} + + self._updates[doc_id][page_no].update(updates) + + async def flush_to_conv_res(self, conv_res: ConversionResult) -> None: + """Apply all accumulated updates atomically""" + async with self._lock: + doc_id = str(id(conv_res)) + if doc_id in self._updates: + # Apply updates + for page_no, updates in self._updates[doc_id].items(): + # Find the page and apply updates + for page in conv_res.pages: + if page.page_no == page_no: + for key, value in updates.items(): + setattr(page, key, value) + break + + del self._updates[doc_id] diff --git a/docling/pipeline/stages.py b/docling/pipeline/stages.py new file mode 100644 index 00000000..b07ce33a --- /dev/null +++ b/docling/pipeline/stages.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import asyncio +import time +from collections import defaultdict +from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List + +from docling.datamodel.document import ConversionResult, InputDocument, Page +from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem + + +class SourceStage(PipelineStage): + """A placeholder stage to represent the entry point of the graph.""" + + async def run(self) -> None: + # This stage is driven by the GraphRunner's _run_source method + # and does not have its own execution loop. + pass + + +class SinkStage(PipelineStage): + """A placeholder stage to represent the exit point of the graph.""" + + async def run(self) -> None: + # This stage is read by the GraphRunner's _run_sink method + # and does not have its own execution loop. 
+ pass + + +class ExtractionStage(PipelineStage): + """Extracts pages from documents and tracks them.""" + + def __init__( + self, + name: str, + page_tracker: Any, + max_concurrent_extractions: int, + ): + super().__init__(name) + self.page_tracker = page_tracker + self.semaphore = asyncio.Semaphore(max_concurrent_extractions) + self.input_channel = "in" + self.output_channel = "out" + self.failure_channel = "fail" + + async def _extract_page( + self, page_no: int, conv_res: ConversionResult + ) -> StreamItem | None: + """Coroutine to extract a single page.""" + try: + async with self.semaphore: + page = Page(page_no=page_no) + conv_res.pages.append(page) + + page._backend = await self.loop.run_in_executor( + None, conv_res.input._backend.load_page, page_no + ) + + if page._backend and page._backend.is_valid(): + page.size = page._backend.get_size() + await self.page_tracker.track_page_loaded(page, conv_res) + return StreamItem( + payload=page, conv_res_id=id(conv_res), conv_res=conv_res + ) + except Exception: + # In case of page-level error, we might log it but continue + # For now, we don't propagate failure here, but in the document level + pass + return None + + async def _process_document(self, in_doc: InputDocument): + """Processes a single document, extracting all its pages.""" + conv_res = ConversionResult(input=in_doc) + + try: + from docling.backend.pdf_backend import PdfDocumentBackend + + if not isinstance(in_doc._backend, PdfDocumentBackend): + raise TypeError("Backend is not a valid PdfDocumentBackend") + + total_pages = in_doc.page_count + await self.page_tracker.register_document(conv_res, total_pages) + + start_page, end_page = conv_res.input.limits.page_range + page_indices_to_extract = [ + i for i in range(total_pages) if (start_page - 1) <= i <= (end_page - 1) + ] + + tasks = [ + self.loop.create_task(self._extract_page(i, conv_res)) + for i in page_indices_to_extract + ] + pages_extracted = await asyncio.gather(*tasks) + + await self._send_to_outputs( + self.output_channel, [p for p in pages_extracted if p] + ) + + except Exception: + conv_res.status = "FAILURE" + await self._send_to_outputs(self.failure_channel, [conv_res]) + + async def run(self) -> None: + """Main loop to consume documents and launch extraction tasks.""" + q_in = self.input_queues[self.input_channel] + while True: + doc = await q_in.get() + if doc is STOP_SENTINEL: + await self._signal_downstream_completion() + break + await self._process_document(doc) + + +class PageProcessorStage(PipelineStage): + """Applies a synchronous, 1-to-1 processing function to each page.""" + + def __init__(self, name: str, model: Any): + super().__init__(name) + self.model = model + self.input_channel = "in" + self.output_channel = "out" + + async def run(self) -> None: + q_in = self.input_queues[self.input_channel] + while True: + item = await q_in.get() + if item is STOP_SENTINEL: + await self._signal_downstream_completion() + break + + # The model call is sync, run in thread to avoid blocking event loop + processed_page = await self.loop.run_in_executor( + None, lambda: next(iter(self.model(item.conv_res, [item.payload]))) + ) + item.payload = processed_page + await self._send_to_outputs(self.output_channel, [item]) + + +class BatchProcessorStage(PipelineStage): + """Batches items and applies a synchronous model to the batch.""" + + def __init__( + self, + name: str, + model: Any, + batch_size: int, + batch_timeout: float, + ): + super().__init__(name) + self.model = model + self.batch_size = batch_size + 
self.batch_timeout = batch_timeout + self.input_channel = "in" + self.output_channel = "out" + + async def _collect_batch(self, q_in: asyncio.Queue) -> List[StreamItem] | None: + """Collects a batch of items from the input queue with a timeout.""" + try: + # Wait for the first item without a timeout + first_item = await q_in.get() + if first_item is STOP_SENTINEL: + return None # End of stream + except asyncio.CancelledError: + return None + + batch = [first_item] + start_time = self.loop.time() + + while len(batch) < self.batch_size: + timeout = self.batch_timeout - (self.loop.time() - start_time) + if timeout <= 0: + break + try: + item = await asyncio.wait_for(q_in.get(), timeout) + if item is STOP_SENTINEL: + # Put sentinel back for other potential consumers or the main loop + await q_in.put(STOP_SENTINEL) + break + batch.append(item) + except asyncio.TimeoutError: + break # Batching timeout reached + return batch + + async def run(self) -> None: + q_in = self.input_queues[self.input_channel] + while True: + batch = await self._collect_batch(q_in) + + if not batch: # This can be None or an empty list. + await self._signal_downstream_completion() + break + + # Group pages by their original ConversionResult + grouped_by_doc = defaultdict(list) + for item in batch: + grouped_by_doc[item.conv_res_id].append(item) + + processed_items = [] + for conv_res_id, items in grouped_by_doc.items(): + conv_res = items[0].conv_res + pages = [item.payload for item in items] + + # The model call is sync, run in thread + processed_pages = await self.loop.run_in_executor( + None, lambda: list(self.model(conv_res, pages)) + ) + + # Re-wrap the processed pages into StreamItems + for i, page in enumerate(processed_pages): + processed_items.append( + StreamItem( + payload=page, + conv_res_id=items[i].conv_res_id, + conv_res=items[i].conv_res, + ) + ) + + await self._send_to_outputs(self.output_channel, processed_items) + + +class AggregationStage(PipelineStage): + """Aggregates processed pages back into completed documents.""" + + def __init__( + self, + name: str, + page_tracker: Any, + finalizer_func: Callable[[ConversionResult], Coroutine], + ): + super().__init__(name) + self.page_tracker = page_tracker + self.finalizer_func = finalizer_func + self.success_channel = "in" + self.failure_channel = "fail" + self.output_channel = "out" + + async def run(self) -> None: + success_q = self.input_queues[self.success_channel] + failure_q = self.input_queues.get(self.failure_channel) + doc_completers: Dict[int, asyncio.Future] = {} + + async def handle_successes(): + while True: + item = await success_q.get() + if item is STOP_SENTINEL: + break + is_doc_complete = await self.page_tracker.track_page_completion( + item.payload, item.conv_res + ) + if is_doc_complete: + await self.finalizer_func(item.conv_res) + await self._send_to_outputs(self.output_channel, [item.conv_res]) + if item.conv_res_id in doc_completers: + doc_completers.pop(item.conv_res_id).set_result(True) + + async def handle_failures(): + if failure_q is None: + return # No failure channel, nothing to do + while True: + failed_res = await failure_q.get() + if failed_res is STOP_SENTINEL: + break + await self._send_to_outputs(self.output_channel, [failed_res]) + if id(failed_res) in doc_completers: + doc_completers.pop(id(failed_res)).set_result(True) + + # Create tasks only for channels that exist + tasks = [handle_successes()] + if failure_q is not None: + tasks.append(handle_failures()) + + await asyncio.gather(*tasks) + await 
self._signal_downstream_completion()
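
Reviewer note: below is a minimal, hypothetical usage sketch of the GraphRunner / PipelineStage API added in docling/pipeline/graph.py, showing how stages are wired by name via edge dicts and driven through the SourceStage / SinkStage placeholders from docling/pipeline/stages.py. UppercaseStage, words, and demo are illustrative names, not part of this diff; the sketch assumes Python 3.11+ for asyncio.TaskGroup, which the pipeline code already relies on. The same wiring pattern is what execute_stream builds for the real extractor -> preprocessor -> ocr -> layout -> table -> assembler -> aggregator chain.

import asyncio

from docling.pipeline.graph import STOP_SENTINEL, GraphRunner, PipelineStage
from docling.pipeline.stages import SinkStage, SourceStage


class UppercaseStage(PipelineStage):
    """Toy 1-to-1 stage: read strings from 'in', uppercase them, emit on 'out'."""

    async def run(self) -> None:
        q_in = self.input_queues["in"]
        while True:
            item = await q_in.get()
            if item is STOP_SENTINEL:
                # Propagate completion to downstream queues, then stop.
                await self._signal_downstream_completion()
                break
            await self._send_to_outputs("out", [item.upper()])


async def demo() -> None:
    # Stages are referenced by name in the edge list, mirroring execute_stream.
    stages = [SourceStage("source"), UppercaseStage("upper"), SinkStage("sink")]
    edges = [
        {"from_stage": "source", "from_output": "out", "to_stage": "upper", "to_input": "in"},
        {"from_stage": "upper", "from_output": "out", "to_stage": "sink", "to_input": "in"},
    ]

    async def words():
        # Any async iterable can feed the source stage.
        for word in ("graph", "runner"):
            yield word

    runner = GraphRunner(stages, edges)
    async for result in runner.run(
        words(),
        source_config={"stage": "source", "channel": "out"},
        sink_config={"stage": "sink", "channel": "in"},
    ):
        print(result)  # -> "GRAPH", then "RUNNER"


if __name__ == "__main__":
    asyncio.run(demo())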