Cleanups and safety improvements

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer 2025-07-16 10:46:32 +02:00
parent 0be9349884
commit f98c7e21dd
4 changed files with 125 additions and 83 deletions

View File

@@ -22,7 +22,7 @@ from docling.models.page_preprocessing_model import (
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.async_base_pipeline import AsyncPipeline
from docling.pipeline.graph import GraphRunner
from docling.pipeline.graph import GraphRunner, get_pipeline_thread_pool
from docling.pipeline.resource_manager import AsyncPageTracker
from docling.pipeline.stages import (
AggregationStage,
@@ -49,6 +49,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
keep_images=self._should_keep_images(),
keep_backend=self._should_keep_backend(),
)
# Get shared thread pool for enrichment operations
self._thread_pool = get_pipeline_thread_pool()
self._initialize_models()
def _should_keep_images(self) -> bool:
@@ -416,9 +418,10 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
for element_batch in chunkify(
elements_to_process, model.elements_batch_size
):
# Run model in thread to avoid blocking
await asyncio.to_thread(
lambda: list(model(conv_res.document, element_batch))
# Run model in shared thread pool to avoid blocking
await asyncio.get_running_loop().run_in_executor(
self._thread_pool,
lambda: list(model(conv_res.document, element_batch)),
)
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
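
The hunk above replaces asyncio.to_thread (which borrows the interpreter's default executor) with loop.run_in_executor against the shared pipeline pool, so enrichment batches compete for one bounded set of worker threads. A minimal, self-contained sketch of that pattern, using illustrative names (fake_model, SHARED_POOL) rather than the docling API:

import asyncio
from concurrent.futures import ThreadPoolExecutor

SHARED_POOL = ThreadPoolExecutor(max_workers=4, thread_name_prefix="demo_pipeline")

def fake_model(batch):
    # Stand-in for a synchronous enrichment model call.
    return [item * 2 for item in batch]

async def enrich(batch):
    loop = asyncio.get_running_loop()
    # Offload the blocking call to the shared executor instead of
    # asyncio.to_thread(), which would use the default thread pool.
    return await loop.run_in_executor(SHARED_POOL, lambda: list(fake_model(batch)))

async def main():
    results = await asyncio.gather(*(enrich([i, i + 1]) for i in range(3)))
    print(results)  # [[0, 2], [2, 4], [4, 6]]

if __name__ == "__main__":
    asyncio.run(main())
    SHARED_POOL.shutdown()
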

View File

@@ -1,14 +1,39 @@
from __future__ import annotations
import asyncio
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, AsyncIterable, Dict, List, Literal, Optional
# Sentinel to signal stream completion
STOP_SENTINEL = object()
# Global thread pool for pipeline operations - shared across all stages
_PIPELINE_THREAD_POOL: Optional[ThreadPoolExecutor] = None
_THREAD_POOL_REFS = weakref.WeakSet()
def get_pipeline_thread_pool(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
"""Get or create the shared pipeline thread pool."""
global _PIPELINE_THREAD_POOL
if _PIPELINE_THREAD_POOL is None or _PIPELINE_THREAD_POOL._shutdown:
_PIPELINE_THREAD_POOL = ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix="docling_pipeline"
)
_THREAD_POOL_REFS.add(_PIPELINE_THREAD_POOL)
return _PIPELINE_THREAD_POOL
def shutdown_pipeline_thread_pool(wait: bool = True) -> None:
"""Shutdown the shared thread pool."""
global _PIPELINE_THREAD_POOL
if _PIPELINE_THREAD_POOL is not None:
_PIPELINE_THREAD_POOL.shutdown(wait=wait)
_PIPELINE_THREAD_POOL = None
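
Rough usage of the two helpers above, assuming the module path shown in the imports (docling.pipeline.graph). Note that max_workers only takes effect when the pool is first created, or recreated after a shutdown:

from docling.pipeline.graph import (
    get_pipeline_thread_pool,
    shutdown_pipeline_thread_pool,
)

pool_a = get_pipeline_thread_pool()
pool_b = get_pipeline_thread_pool()
assert pool_a is pool_b  # repeated calls reuse the same shared executor

shutdown_pipeline_thread_pool(wait=True)
pool_c = get_pipeline_thread_pool()
assert pool_c is not pool_a  # a fresh pool is created lazily after shutdown
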
@dataclass(slots=True)
class StreamItem:
@@ -25,11 +50,12 @@ class StreamItem:
class PipelineStage(ABC):
"""A single, encapsulated step in a processing pipeline graph."""
def __init__(self, name: str):
def __init__(self, name: str, max_workers: Optional[int] = None):
self.name = name
self.input_queues: Dict[str, asyncio.Queue] = {}
self.output_queues: Dict[str, List[asyncio.Queue]] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._thread_pool = get_pipeline_thread_pool(max_workers)
@abstractmethod
async def run(self) -> None:
@@ -58,6 +84,11 @@
self._loop = asyncio.get_running_loop()
return self._loop
@property
def thread_pool(self) -> ThreadPoolExecutor:
"""Get the shared thread pool for this stage."""
return self._thread_pool
class GraphRunner:
"""Connects stages and runs the pipeline graph."""
@@ -125,20 +156,25 @@
"""
self._wire_graph(queue_max_size)
async with asyncio.TaskGroup() as tg:
# Create a task for the source feeder
tg.create_task(
self._run_source(
source_stream, source_config["stage"], source_config["channel"]
try:
async with asyncio.TaskGroup() as tg:
# Create a task for the source feeder
tg.create_task(
self._run_source(
source_stream, source_config["stage"], source_config["channel"]
)
)
)
# Create tasks for all pipeline stages
for stage in self._stages.values():
tg.create_task(stage.run())
# Create tasks for all pipeline stages
for stage in self._stages.values():
tg.create_task(stage.run())
# Yield results from the sink
async for result in self._run_sink(
sink_config["stage"], sink_config["channel"]
):
yield result
# Yield results from the sink
async for result in self._run_sink(
sink_config["stage"], sink_config["channel"]
):
yield result
finally:
# Ensure thread pool cleanup on pipeline completion
# Note: We don't shutdown here as other pipelines might be using it
pass
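
The try/finally keeps a cleanup point around the TaskGroup without shutting the shared pool down. With asyncio.TaskGroup (Python 3.11+), one failing stage cancels its siblings and the error resurfaces as an ExceptionGroup when the block exits; a standalone sketch of that behavior with toy stages, not docling code:

import asyncio

async def worker(name, delay, fail=False):
    await asyncio.sleep(delay)
    if fail:
        raise RuntimeError(f"{name} failed")
    print(f"{name} finished")

async def main():
    try:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(worker("stage-a", 0.1))
            tg.create_task(worker("stage-b", 0.05, fail=True))  # failure cancels siblings
            tg.create_task(worker("stage-c", 1.0))              # never completes
    except* RuntimeError as eg:
        print("pipeline aborted:", [str(e) for e in eg.exceptions])
    finally:
        # Shared resources (e.g. a thread pool) would be released here;
        # the commit deliberately leaves the pool alive for other pipelines.
        pass

asyncio.run(main())
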

View File

@@ -1,11 +1,13 @@
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from typing import Dict, Optional, Set
from typing import Dict, Optional
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.pipeline.async_base_pipeline import DocumentTracker
from docling.pipeline.graph import get_pipeline_thread_pool
_log = logging.getLogger(__name__)
@@ -19,22 +21,29 @@ class AsyncPageTracker:
keep_images: bool = False
keep_backend: bool = False
def __post_init__(self):
"""Initialize shared thread pool reference after dataclass creation"""
self._thread_pool = get_pipeline_thread_pool()
async def register_document(
self, conv_res: ConversionResult, total_pages: int
) -> str:
"""Register a new document for tracking"""
async with self._lock:
doc_id = str(id(conv_res))
# Use UUID for better collision resistance than str(id())
doc_id = str(uuid.uuid4())
self._doc_trackers[doc_id] = DocumentTracker(
doc_id=doc_id, total_pages=total_pages, conv_result=conv_res
)
# Store the doc_id in the conv_res for later lookup
conv_res._async_doc_id = doc_id
return doc_id
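
The id()-based keys were replaced because CPython can reuse an object's address once the original ConversionResult is garbage collected, silently colliding with a live tracker entry; uuid4 keys are independent of object lifetime. A small illustration (the collision shown is typical CPython behavior, not guaranteed):

import uuid

class Doc:
    pass

first = Doc()
first_key = str(id(first))  # the old scheme: key tied to the object's address
del first
second = Doc()
print(first_key == str(id(second)))            # often True on CPython - a silent collision
print(str(uuid.uuid4()) == str(uuid.uuid4()))  # False: new scheme, lifetime-independent
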
async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None:
"""Track when a page backend is loaded"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id in self._doc_trackers and page._backend is not None:
doc_id = getattr(conv_res, "_async_doc_id", None)
if doc_id and doc_id in self._doc_trackers and page._backend is not None:
self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend
async def track_page_completion(
@@ -42,8 +51,8 @@ class AsyncPageTracker:
) -> bool:
"""Track page completion and cleanup when all pages done"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id not in self._doc_trackers:
doc_id = getattr(conv_res, "_async_doc_id", None)
if not doc_id or doc_id not in self._doc_trackers:
_log.warning(f"Document {doc_id} not registered for tracking")
return False
@@ -58,6 +67,9 @@
if tracker.processed_pages == tracker.total_pages:
await self._cleanup_document_resources(tracker)
del self._doc_trackers[doc_id]
# Clean up the doc_id from conv_res
if hasattr(conv_res, "_async_doc_id"):
delattr(conv_res, "_async_doc_id")
return True # Document is complete
return False # Document is not yet complete
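
The completion path above reduces to a lock-guarded per-document counter that triggers cleanup exactly once when the count reaches the page total. A stripped-down sketch of that bookkeeping, with illustrative names rather than the docling classes:

import asyncio
from dataclasses import dataclass
from typing import Dict

@dataclass
class _Tracker:
    total_pages: int
    processed_pages: int = 0

class CompletionTracker:
    def __init__(self) -> None:
        self._trackers: Dict[str, _Tracker] = {}
        self._lock = asyncio.Lock()

    async def register(self, doc_id: str, total_pages: int) -> None:
        async with self._lock:
            self._trackers[doc_id] = _Tracker(total_pages)

    async def page_done(self, doc_id: str) -> bool:
        """Return True once every page of the document has been processed."""
        async with self._lock:
            tracker = self._trackers.get(doc_id)
            if tracker is None:
                return False  # unknown document - mirrors the warning path above
            tracker.processed_pages += 1
            if tracker.processed_pages == tracker.total_pages:
                del self._trackers[doc_id]  # cleanup runs exactly once
                return True
            return False

async def demo():
    tracker = CompletionTracker()
    await tracker.register("doc-1", total_pages=2)
    print(await tracker.page_done("doc-1"))  # False
    print(await tracker.page_done("doc-1"))  # True

asyncio.run(demo())
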
@@ -69,8 +81,10 @@
for page_no, backend in tracker.page_backends.items():
if backend is not None:
try:
# Run unload in thread to avoid blocking
await asyncio.to_thread(backend.unload)
# Run unload in shared thread pool to avoid blocking
await asyncio.get_running_loop().run_in_executor(
self._thread_pool, backend.unload
)
except Exception as e:
_log.warning(
f"Failed to unload backend for page {page_no}: {e}"
@@ -85,41 +99,3 @@
for tracker in self._doc_trackers.values():
await self._cleanup_document_resources(tracker)
self._doc_trackers.clear()
@dataclass
class ConversionResultAccumulator:
"""Accumulates updates to ConversionResult without immediate mutation"""
_updates: Dict[str, Dict] = field(default_factory=dict) # doc_id -> updates
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
async def accumulate_page_result(
self, page_no: int, conv_res: ConversionResult, updates: Dict
) -> None:
"""Accumulate updates for later application"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id not in self._updates:
self._updates[doc_id] = {}
if page_no not in self._updates[doc_id]:
self._updates[doc_id][page_no] = {}
self._updates[doc_id][page_no].update(updates)
async def flush_to_conv_res(self, conv_res: ConversionResult) -> None:
"""Apply all accumulated updates atomically"""
async with self._lock:
doc_id = str(id(conv_res))
if doc_id in self._updates:
# Apply updates
for page_no, updates in self._updates[doc_id].items():
# Find the page and apply updates
for page in conv_res.pages:
if page.page_no == page_no:
for key, value in updates.items():
setattr(page, key, value)
break
del self._updates[doc_id]

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import logging
import time
from collections import defaultdict
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
@@ -8,6 +9,8 @@ from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
from docling.datamodel.document import ConversionResult, InputDocument, Page
from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem
_log = logging.getLogger(__name__)
class SourceStage(PipelineStage):
"""A placeholder stage to represent the entry point of the graph."""
@@ -53,7 +56,7 @@ class ExtractionStage(PipelineStage):
conv_res.pages.append(page)
page._backend = await self.loop.run_in_executor(
None, conv_res.input._backend.load_page, page_no
self.thread_pool, conv_res.input._backend.load_page, page_no
)
if page._backend and page._backend.is_valid():
@@ -62,11 +65,18 @@
return StreamItem(
payload=page, conv_res_id=id(conv_res), conv_res=conv_res
)
except Exception:
# In case of page-level error, we might log it but continue
# For now, we don't propagate failure here, but in the document level
pass
return None
else:
_log.warning(
f"Failed to load or validate page {page_no} from document {conv_res.input.file.name}"
)
return None
except Exception as e:
_log.error(
f"Error extracting page {page_no} from document {conv_res.input.file.name}: {e}",
exc_info=True,
)
# Don't propagate individual page failures - document-level error handling will catch this
return None
async def _process_document(self, in_doc: InputDocument):
"""Processes a single document, extracting all its pages."""
@@ -90,13 +100,34 @@
self.loop.create_task(self._extract_page(i, conv_res))
for i in page_indices_to_extract
]
pages_extracted = await asyncio.gather(*tasks)
pages_extracted = await asyncio.gather(*tasks, return_exceptions=True)
await self._send_to_outputs(
self.output_channel, [p for p in pages_extracted if p]
# Filter out None results and exceptions, log any exceptions found
valid_pages = []
for i, result in enumerate(pages_extracted):
if isinstance(result, Exception):
_log.error(
f"Page extraction failed for page {page_indices_to_extract[i]} "
f"in document {in_doc.file.name}: {result}"
)
elif result is not None:
valid_pages.append(result)
await self._send_to_outputs(self.output_channel, valid_pages)
# If no pages were successfully extracted, mark as failure
if not valid_pages:
_log.error(
f"No pages could be extracted from document {in_doc.file.name}"
)
conv_res.status = "FAILURE"
await self._send_to_outputs(self.failure_channel, [conv_res])
except Exception as e:
_log.error(
f"Document-level extraction failed for {in_doc.file.name}: {e}",
exc_info=True,
)
except Exception:
conv_res.status = "FAILURE"
await self._send_to_outputs(self.failure_channel, [conv_res])
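
With return_exceptions=True a single failing page no longer aborts the whole gather: exceptions come back as values to be logged, None results are dropped, and the document only goes to the failure channel when nothing survives. A self-contained sketch of that filtering, with toy names rather than the stage API:

import asyncio
from typing import Optional

async def extract_page(page_no: int) -> Optional[str]:
    if page_no == 2:
        raise ValueError(f"corrupt page {page_no}")
    if page_no == 3:
        return None  # recoverable failure, assumed already logged
    return f"page-{page_no}"

async def main():
    page_numbers = [1, 2, 3, 4]
    results = await asyncio.gather(
        *(extract_page(n) for n in page_numbers), return_exceptions=True
    )
    valid, errors = [], []
    for page_no, result in zip(page_numbers, results):
        if isinstance(result, Exception):
            errors.append((page_no, result))  # the stage logs these instead
        elif result is not None:
            valid.append(result)
    print("valid:", valid)    # ['page-1', 'page-4']
    print("errors:", errors)  # one ValueError for page 2
    if not valid:
        print("document would be marked FAILURE")

asyncio.run(main())
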
@@ -130,7 +161,8 @@ class PageProcessorStage(PipelineStage):
# The model call is sync, run in thread to avoid blocking event loop
processed_page = await self.loop.run_in_executor(
None, lambda: next(iter(self.model(item.conv_res, [item.payload])))
self.thread_pool,
lambda: next(iter(self.model(item.conv_res, [item.payload]))),
)
item.payload = processed_page
await self._send_to_outputs(self.output_channel, [item])
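
loop.run_in_executor only forwards positional arguments, which is why the stage freezes the model call into a zero-argument lambda; functools.partial works equally well. A small sketch of the same wrapping, with placeholder names (model here is not the docling model interface):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

pool = ThreadPoolExecutor(max_workers=2)

def model(doc, pages):
    # Stand-in for a synchronous model that yields processed pages.
    for page in pages:
        yield f"{doc}:{page}:processed"

def first_processed(doc, page):
    # Equivalent of next(iter(self.model(conv_res, [item.payload]))) above.
    return next(iter(model(doc, [page])))

async def process_one(doc, page):
    loop = asyncio.get_running_loop()
    # Bind the arguments up front, then hand the no-argument callable to the pool.
    return await loop.run_in_executor(pool, partial(first_processed, doc, page))

print(asyncio.run(process_one("doc-1", "page-7")))  # doc-1:page-7:processed
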
@@ -202,7 +234,7 @@ class BatchProcessorStage(PipelineStage):
# The model call is sync, run in thread
processed_pages = await self.loop.run_in_executor(
None, lambda: list(self.model(conv_res, pages))
self.thread_pool, lambda: list(self.model(conv_res, pages))
)
# Re-wrap the processed pages into StreamItems
@@ -237,7 +269,6 @@ class AggregationStage(PipelineStage):
async def run(self) -> None:
success_q = self.input_queues[self.success_channel]
failure_q = self.input_queues.get(self.failure_channel)
doc_completers: Dict[int, asyncio.Future] = {}
async def handle_successes():
while True:
@@ -250,8 +281,6 @@
if is_doc_complete:
await self.finalizer_func(item.conv_res)
await self._send_to_outputs(self.output_channel, [item.conv_res])
if item.conv_res_id in doc_completers:
doc_completers.pop(item.conv_res_id).set_result(True)
async def handle_failures():
if failure_q is None:
@@ -261,8 +290,6 @@
if failed_res is STOP_SENTINEL:
break
await self._send_to_outputs(self.output_channel, [failed_res])
if id(failed_res) in doc_completers:
doc_completers.pop(id(failed_res)).set_result(True)
# Create tasks only for channels that exist
tasks = [handle_successes()]
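
With the unused doc_completers futures removed, the aggregation stage is just two concurrent consumers draining a success channel and an optional failure channel until each delivers the STOP sentinel. A condensed, self-contained sketch of that fan-in pattern (STOP stands in for STOP_SENTINEL from docling.pipeline.graph):

import asyncio

STOP = object()  # stands in for STOP_SENTINEL

async def fan_in(success_q: asyncio.Queue, failure_q: asyncio.Queue, out: list) -> None:
    async def handle_successes():
        while True:
            item = await success_q.get()
            if item is STOP:
                break
            out.append(("ok", item))

    async def handle_failures():
        while True:
            item = await failure_q.get()
            if item is STOP:
                break
            out.append(("failed", item))

    # Both consumers run concurrently; the stage finishes once each channel
    # has delivered its sentinel.
    await asyncio.gather(handle_successes(), handle_failures())

async def main():
    success_q, failure_q, out = asyncio.Queue(), asyncio.Queue(), []
    for item in ("doc-1", STOP):
        await success_q.put(item)
    for item in ("doc-2", STOP):
        await failure_q.put(item)
    await fan_in(success_q, failure_q, out)
    print(out)  # [('ok', 'doc-1'), ('failed', 'doc-2')]

asyncio.run(main())
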