From f98c7e21dd3cf0b966aee9028bc7c478c2ae875b Mon Sep 17 00:00:00 2001 From: Christoph Auer Date: Wed, 16 Jul 2025 10:46:32 +0200 Subject: [PATCH] Cleanups and safety improvements Signed-off-by: Christoph Auer --- .../pipeline/async_standard_pdf_pipeline.py | 11 +-- docling/pipeline/graph.py | 66 ++++++++++++++---- docling/pipeline/resource_manager.py | 68 ++++++------------- docling/pipeline/stages.py | 63 ++++++++++++----- 4 files changed, 125 insertions(+), 83 deletions(-) diff --git a/docling/pipeline/async_standard_pdf_pipeline.py b/docling/pipeline/async_standard_pdf_pipeline.py index 47c25a47..cf12ec0c 100644 --- a/docling/pipeline/async_standard_pdf_pipeline.py +++ b/docling/pipeline/async_standard_pdf_pipeline.py @@ -22,7 +22,7 @@ from docling.models.page_preprocessing_model import ( from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions from docling.models.table_structure_model import TableStructureModel from docling.pipeline.async_base_pipeline import AsyncPipeline -from docling.pipeline.graph import GraphRunner +from docling.pipeline.graph import GraphRunner, get_pipeline_thread_pool from docling.pipeline.resource_manager import AsyncPageTracker from docling.pipeline.stages import ( AggregationStage, @@ -49,6 +49,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline): keep_images=self._should_keep_images(), keep_backend=self._should_keep_backend(), ) + # Get shared thread pool for enrichment operations + self._thread_pool = get_pipeline_thread_pool() self._initialize_models() def _should_keep_images(self) -> bool: @@ -416,9 +418,10 @@ class AsyncStandardPdfPipeline(AsyncPipeline): for element_batch in chunkify( elements_to_process, model.elements_batch_size ): - # Run model in thread to avoid blocking - await asyncio.to_thread( - lambda: list(model(conv_res.document, element_batch)) + # Run model in shared thread pool to avoid blocking + await asyncio.get_running_loop().run_in_executor( + self._thread_pool, + lambda: list(model(conv_res.document, element_batch)), ) def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus: diff --git a/docling/pipeline/graph.py b/docling/pipeline/graph.py index 0d9827a1..990b2178 100644 --- a/docling/pipeline/graph.py +++ b/docling/pipeline/graph.py @@ -1,14 +1,39 @@ from __future__ import annotations import asyncio +import weakref from abc import ABC, abstractmethod from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Any, AsyncIterable, Dict, List, Literal, Optional # Sentinel to signal stream completion STOP_SENTINEL = object() +# Global thread pool for pipeline operations - shared across all stages +_PIPELINE_THREAD_POOL: Optional[ThreadPoolExecutor] = None +_THREAD_POOL_REFS = weakref.WeakSet() + + +def get_pipeline_thread_pool(max_workers: Optional[int] = None) -> ThreadPoolExecutor: + """Get or create the shared pipeline thread pool.""" + global _PIPELINE_THREAD_POOL + if _PIPELINE_THREAD_POOL is None or _PIPELINE_THREAD_POOL._shutdown: + _PIPELINE_THREAD_POOL = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="docling_pipeline" + ) + _THREAD_POOL_REFS.add(_PIPELINE_THREAD_POOL) + return _PIPELINE_THREAD_POOL + + +def shutdown_pipeline_thread_pool(wait: bool = True) -> None: + """Shutdown the shared thread pool.""" + global _PIPELINE_THREAD_POOL + if _PIPELINE_THREAD_POOL is not None: + _PIPELINE_THREAD_POOL.shutdown(wait=wait) + _PIPELINE_THREAD_POOL = None + @dataclass(slots=True) class 
StreamItem: @@ -25,11 +50,12 @@ class StreamItem: class PipelineStage(ABC): """A single, encapsulated step in a processing pipeline graph.""" - def __init__(self, name: str): + def __init__(self, name: str, max_workers: Optional[int] = None): self.name = name self.input_queues: Dict[str, asyncio.Queue] = {} self.output_queues: Dict[str, List[asyncio.Queue]] = {} self._loop: Optional[asyncio.AbstractEventLoop] = None + self._thread_pool = get_pipeline_thread_pool(max_workers) @abstractmethod async def run(self) -> None: @@ -58,6 +84,11 @@ class PipelineStage(ABC): self._loop = asyncio.get_running_loop() return self._loop + @property + def thread_pool(self) -> ThreadPoolExecutor: + """Get the shared thread pool for this stage.""" + return self._thread_pool + class GraphRunner: """Connects stages and runs the pipeline graph.""" @@ -125,20 +156,25 @@ class GraphRunner: """ self._wire_graph(queue_max_size) - async with asyncio.TaskGroup() as tg: - # Create a task for the source feeder - tg.create_task( - self._run_source( - source_stream, source_config["stage"], source_config["channel"] + try: + async with asyncio.TaskGroup() as tg: + # Create a task for the source feeder + tg.create_task( + self._run_source( + source_stream, source_config["stage"], source_config["channel"] + ) ) - ) - # Create tasks for all pipeline stages - for stage in self._stages.values(): - tg.create_task(stage.run()) + # Create tasks for all pipeline stages + for stage in self._stages.values(): + tg.create_task(stage.run()) - # Yield results from the sink - async for result in self._run_sink( - sink_config["stage"], sink_config["channel"] - ): - yield result + # Yield results from the sink + async for result in self._run_sink( + sink_config["stage"], sink_config["channel"] + ): + yield result + finally: + # Ensure thread pool cleanup on pipeline completion + # Note: We don't shutdown here as other pipelines might be using it + pass diff --git a/docling/pipeline/resource_manager.py b/docling/pipeline/resource_manager.py index 4e1d5e2f..230fb4c7 100644 --- a/docling/pipeline/resource_manager.py +++ b/docling/pipeline/resource_manager.py @@ -1,11 +1,13 @@ import asyncio import logging +import uuid from dataclasses import dataclass, field -from typing import Dict, Optional, Set +from typing import Dict, Optional from docling.datamodel.base_models import Page from docling.datamodel.document import ConversionResult from docling.pipeline.async_base_pipeline import DocumentTracker +from docling.pipeline.graph import get_pipeline_thread_pool _log = logging.getLogger(__name__) @@ -19,22 +21,29 @@ class AsyncPageTracker: keep_images: bool = False keep_backend: bool = False + def __post_init__(self): + """Initialize shared thread pool reference after dataclass creation""" + self._thread_pool = get_pipeline_thread_pool() + async def register_document( self, conv_res: ConversionResult, total_pages: int ) -> str: """Register a new document for tracking""" async with self._lock: - doc_id = str(id(conv_res)) + # Use UUID for better collision resistance than str(id()) + doc_id = str(uuid.uuid4()) self._doc_trackers[doc_id] = DocumentTracker( doc_id=doc_id, total_pages=total_pages, conv_result=conv_res ) + # Store the doc_id in the conv_res for later lookup + conv_res._async_doc_id = doc_id return doc_id async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None: """Track when a page backend is loaded""" async with self._lock: - doc_id = str(id(conv_res)) - if doc_id in self._doc_trackers and page._backend is not 
None: + doc_id = getattr(conv_res, "_async_doc_id", None) + if doc_id and doc_id in self._doc_trackers and page._backend is not None: self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend async def track_page_completion( @@ -42,8 +51,8 @@ class AsyncPageTracker: ) -> bool: """Track page completion and cleanup when all pages done""" async with self._lock: - doc_id = str(id(conv_res)) - if doc_id not in self._doc_trackers: + doc_id = getattr(conv_res, "_async_doc_id", None) + if not doc_id or doc_id not in self._doc_trackers: _log.warning(f"Document {doc_id} not registered for tracking") return False @@ -58,6 +67,9 @@ class AsyncPageTracker: if tracker.processed_pages == tracker.total_pages: await self._cleanup_document_resources(tracker) del self._doc_trackers[doc_id] + # Clean up the doc_id from conv_res + if hasattr(conv_res, "_async_doc_id"): + delattr(conv_res, "_async_doc_id") return True # Document is complete return False # Document is not yet complete @@ -69,8 +81,10 @@ class AsyncPageTracker: for page_no, backend in tracker.page_backends.items(): if backend is not None: try: - # Run unload in thread to avoid blocking - await asyncio.to_thread(backend.unload) + # Run unload in shared thread pool to avoid blocking + await asyncio.get_running_loop().run_in_executor( + self._thread_pool, backend.unload + ) except Exception as e: _log.warning( f"Failed to unload backend for page {page_no}: {e}" @@ -85,41 +99,3 @@ class AsyncPageTracker: for tracker in self._doc_trackers.values(): await self._cleanup_document_resources(tracker) self._doc_trackers.clear() - - -@dataclass -class ConversionResultAccumulator: - """Accumulates updates to ConversionResult without immediate mutation""" - - _updates: Dict[str, Dict] = field(default_factory=dict) # doc_id -> updates - _lock: asyncio.Lock = field(default_factory=asyncio.Lock) - - async def accumulate_page_result( - self, page_no: int, conv_res: ConversionResult, updates: Dict - ) -> None: - """Accumulate updates for later application""" - async with self._lock: - doc_id = str(id(conv_res)) - if doc_id not in self._updates: - self._updates[doc_id] = {} - - if page_no not in self._updates[doc_id]: - self._updates[doc_id][page_no] = {} - - self._updates[doc_id][page_no].update(updates) - - async def flush_to_conv_res(self, conv_res: ConversionResult) -> None: - """Apply all accumulated updates atomically""" - async with self._lock: - doc_id = str(id(conv_res)) - if doc_id in self._updates: - # Apply updates - for page_no, updates in self._updates[doc_id].items(): - # Find the page and apply updates - for page in conv_res.pages: - if page.page_no == page_no: - for key, value in updates.items(): - setattr(page, key, value) - break - - del self._updates[doc_id] diff --git a/docling/pipeline/stages.py b/docling/pipeline/stages.py index b07ce33a..010dc37c 100644 --- a/docling/pipeline/stages.py +++ b/docling/pipeline/stages.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import time from collections import defaultdict from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List @@ -8,6 +9,8 @@ from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List from docling.datamodel.document import ConversionResult, InputDocument, Page from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem +_log = logging.getLogger(__name__) + class SourceStage(PipelineStage): """A placeholder stage to represent the entry point of the graph.""" @@ -53,7 +56,7 @@ class 
ExtractionStage(PipelineStage): conv_res.pages.append(page) page._backend = await self.loop.run_in_executor( - None, conv_res.input._backend.load_page, page_no + self.thread_pool, conv_res.input._backend.load_page, page_no ) if page._backend and page._backend.is_valid(): @@ -62,11 +65,18 @@ class ExtractionStage(PipelineStage): return StreamItem( payload=page, conv_res_id=id(conv_res), conv_res=conv_res ) - except Exception: - # In case of page-level error, we might log it but continue - # For now, we don't propagate failure here, but in the document level - pass - return None + else: + _log.warning( + f"Failed to load or validate page {page_no} from document {conv_res.input.file.name}" + ) + return None + except Exception as e: + _log.error( + f"Error extracting page {page_no} from document {conv_res.input.file.name}: {e}", + exc_info=True, + ) + # Don't propagate individual page failures - document-level error handling will catch this + return None async def _process_document(self, in_doc: InputDocument): """Processes a single document, extracting all its pages.""" @@ -90,13 +100,34 @@ class ExtractionStage(PipelineStage): self.loop.create_task(self._extract_page(i, conv_res)) for i in page_indices_to_extract ] - pages_extracted = await asyncio.gather(*tasks) + pages_extracted = await asyncio.gather(*tasks, return_exceptions=True) - await self._send_to_outputs( - self.output_channel, [p for p in pages_extracted if p] + # Filter out None results and exceptions, log any exceptions found + valid_pages = [] + for i, result in enumerate(pages_extracted): + if isinstance(result, Exception): + _log.error( + f"Page extraction failed for page {page_indices_to_extract[i]} " + f"in document {in_doc.file.name}: {result}" + ) + elif result is not None: + valid_pages.append(result) + + await self._send_to_outputs(self.output_channel, valid_pages) + + # If no pages were successfully extracted, mark as failure + if not valid_pages: + _log.error( + f"No pages could be extracted from document {in_doc.file.name}" + ) + conv_res.status = "FAILURE" + await self._send_to_outputs(self.failure_channel, [conv_res]) + + except Exception as e: + _log.error( + f"Document-level extraction failed for {in_doc.file.name}: {e}", + exc_info=True, ) - - except Exception: conv_res.status = "FAILURE" await self._send_to_outputs(self.failure_channel, [conv_res]) @@ -130,7 +161,8 @@ class PageProcessorStage(PipelineStage): # The model call is sync, run in thread to avoid blocking event loop processed_page = await self.loop.run_in_executor( - None, lambda: next(iter(self.model(item.conv_res, [item.payload]))) + self.thread_pool, + lambda: next(iter(self.model(item.conv_res, [item.payload]))), ) item.payload = processed_page await self._send_to_outputs(self.output_channel, [item]) @@ -202,7 +234,7 @@ class BatchProcessorStage(PipelineStage): # The model call is sync, run in thread processed_pages = await self.loop.run_in_executor( - None, lambda: list(self.model(conv_res, pages)) + self.thread_pool, lambda: list(self.model(conv_res, pages)) ) # Re-wrap the processed pages into StreamItems @@ -237,7 +269,6 @@ class AggregationStage(PipelineStage): async def run(self) -> None: success_q = self.input_queues[self.success_channel] failure_q = self.input_queues.get(self.failure_channel) - doc_completers: Dict[int, asyncio.Future] = {} async def handle_successes(): while True: @@ -250,8 +281,6 @@ class AggregationStage(PipelineStage): if is_doc_complete: await self.finalizer_func(item.conv_res) await 
self._send_to_outputs(self.output_channel, [item.conv_res]) - if item.conv_res_id in doc_completers: - doc_completers.pop(item.conv_res_id).set_result(True) async def handle_failures(): if failure_q is None: @@ -261,8 +290,6 @@ class AggregationStage(PipelineStage): if failed_res is STOP_SENTINEL: break await self._send_to_outputs(self.output_channel, [failed_res]) - if id(failed_res) in doc_completers: - doc_completers.pop(id(failed_res)).set_result(True) # Create tasks only for channels that exist tasks = [handle_successes()]
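
Editor's note on the pattern this patch converges on: blocking work (page backend loads/unloads, model calls, enrichment batches) is routed through the single shared executor returned by get_pipeline_thread_pool() instead of ad-hoc asyncio.to_thread() calls, so every stage reuses one "docling_pipeline" pool. The sketch below illustrates that pattern for a custom stage. It is a minimal, hedged example that assumes only the PipelineStage plumbing visible in this diff (input_queues, loop, thread_pool, _send_to_outputs, STOP_SENTINEL); ExampleBlockingStage and _blocking_work are hypothetical names and are not part of the change.

    from docling.pipeline.graph import STOP_SENTINEL, PipelineStage


    class ExampleBlockingStage(PipelineStage):
        """Hypothetical stage showing the shared-thread-pool pattern from this patch."""

        def __init__(self, name: str, input_channel: str, output_channel: str):
            super().__init__(name)  # base __init__ acquires the shared pool
            self.input_channel = input_channel
            self.output_channel = output_channel

        def _blocking_work(self, payload):
            # Stand-in for any synchronous model/backend call.
            return payload

        async def run(self) -> None:
            q_in = self.input_queues[self.input_channel]
            while True:
                item = await q_in.get()
                if item is STOP_SENTINEL:
                    # Downstream completion signaling is elided in this sketch.
                    break
                # Same pattern as the patch: offload sync work to the shared
                # executor instead of asyncio.to_thread(), so all stages and
                # the enrichment/unload paths share one thread pool.
                item.payload = await self.loop.run_in_executor(
                    self.thread_pool, self._blocking_work, item.payload
                )
                await self._send_to_outputs(self.output_channel, [item])

If an application knows it is finished with every pipeline, shutdown_pipeline_thread_pool(wait=True) from docling.pipeline.graph releases the shared workers; GraphRunner.run() deliberately leaves the pool alive in its finally block because other pipelines may still be using it.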