Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-25 03:24:59 +00:00)
Cleanups and safety improvements

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

This commit is contained in:
  parent 0be9349884
  commit f98c7e21dd
@@ -22,7 +22,7 @@ from docling.models.page_preprocessing_model import (
 from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
 from docling.models.table_structure_model import TableStructureModel
 from docling.pipeline.async_base_pipeline import AsyncPipeline
-from docling.pipeline.graph import GraphRunner
+from docling.pipeline.graph import GraphRunner, get_pipeline_thread_pool
 from docling.pipeline.resource_manager import AsyncPageTracker
 from docling.pipeline.stages import (
     AggregationStage,
@@ -49,6 +49,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
             keep_images=self._should_keep_images(),
             keep_backend=self._should_keep_backend(),
         )
+        # Get shared thread pool for enrichment operations
+        self._thread_pool = get_pipeline_thread_pool()
         self._initialize_models()

     def _should_keep_images(self) -> bool:
@@ -416,9 +418,10 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
            for element_batch in chunkify(
                elements_to_process, model.elements_batch_size
            ):
-                # Run model in thread to avoid blocking
-                await asyncio.to_thread(
-                    lambda: list(model(conv_res.document, element_batch))
+                # Run model in shared thread pool to avoid blocking
+                await asyncio.get_running_loop().run_in_executor(
+                    self._thread_pool,
+                    lambda: list(model(conv_res.document, element_batch)),
                 )

     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
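For context on the hunk above: asyncio.to_thread always dispatches to the event loop's default executor, while run_in_executor lets the pipeline pin blocking calls to one explicitly shared pool. A minimal, self-contained sketch of the two patterns, not part of the diff; the blocking_work function is a hypothetical stand-in for the synchronous model call:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical stand-in for the synchronous model call in the hunk above.
    def blocking_work(x: int) -> int:
        return x * x

    async def main() -> None:
        shared_pool = ThreadPoolExecutor(max_workers=2, thread_name_prefix="docling_pipeline")

        # Old pattern: to_thread runs on the event loop's default executor.
        squared = await asyncio.to_thread(blocking_work, 3)

        # New pattern: run_in_executor pins the call to an explicit shared pool,
        # so every stage draws from the same bounded set of threads.
        loop = asyncio.get_running_loop()
        squared_shared = await loop.run_in_executor(shared_pool, blocking_work, 4)

        print(squared, squared_shared)  # 9 16
        shared_pool.shutdown(wait=True)

    asyncio.run(main())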
@@ -1,14 +1,39 @@
 from __future__ import annotations

 import asyncio
+import weakref
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import Any, AsyncIterable, Dict, List, Literal, Optional

 # Sentinel to signal stream completion
 STOP_SENTINEL = object()

+# Global thread pool for pipeline operations - shared across all stages
+_PIPELINE_THREAD_POOL: Optional[ThreadPoolExecutor] = None
+_THREAD_POOL_REFS = weakref.WeakSet()
+
+
+def get_pipeline_thread_pool(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
+    """Get or create the shared pipeline thread pool."""
+    global _PIPELINE_THREAD_POOL
+    if _PIPELINE_THREAD_POOL is None or _PIPELINE_THREAD_POOL._shutdown:
+        _PIPELINE_THREAD_POOL = ThreadPoolExecutor(
+            max_workers=max_workers, thread_name_prefix="docling_pipeline"
+        )
+        _THREAD_POOL_REFS.add(_PIPELINE_THREAD_POOL)
+    return _PIPELINE_THREAD_POOL
+
+
+def shutdown_pipeline_thread_pool(wait: bool = True) -> None:
+    """Shutdown the shared thread pool."""
+    global _PIPELINE_THREAD_POOL
+    if _PIPELINE_THREAD_POOL is not None:
+        _PIPELINE_THREAD_POOL.shutdown(wait=wait)
+        _PIPELINE_THREAD_POOL = None
+

 @dataclass(slots=True)
 class StreamItem:
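The two helpers added above implement a lazily created, module-level singleton pool. A small usage sketch, assuming this commit's docling.pipeline.graph module is importable; the _worker helper is hypothetical:

    from docling.pipeline.graph import (
        get_pipeline_thread_pool,
        shutdown_pipeline_thread_pool,
    )

    # Hypothetical blocking helper, just to exercise the pool.
    def _worker(n: int) -> int:
        return n + 1

    # Repeated lookups return the same shared executor instance.
    pool_a = get_pipeline_thread_pool()
    pool_b = get_pipeline_thread_pool()
    assert pool_a is pool_b

    print(pool_a.submit(_worker, 41).result())  # 42

    # After an explicit shutdown, the next lookup creates a fresh pool.
    shutdown_pipeline_thread_pool(wait=True)
    assert get_pipeline_thread_pool() is not pool_a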
@@ -25,11 +50,12 @@ class StreamItem:
 class PipelineStage(ABC):
     """A single, encapsulated step in a processing pipeline graph."""

-    def __init__(self, name: str):
+    def __init__(self, name: str, max_workers: Optional[int] = None):
         self.name = name
         self.input_queues: Dict[str, asyncio.Queue] = {}
         self.output_queues: Dict[str, List[asyncio.Queue]] = {}
         self._loop: Optional[asyncio.AbstractEventLoop] = None
+        self._thread_pool = get_pipeline_thread_pool(max_workers)

     @abstractmethod
     async def run(self) -> None:
@@ -58,6 +84,11 @@ class PipelineStage(ABC):
         self._loop = asyncio.get_running_loop()
         return self._loop

+    @property
+    def thread_pool(self) -> ThreadPoolExecutor:
+        """Get the shared thread pool for this stage."""
+        return self._thread_pool
+

 class GraphRunner:
     """Connects stages and runs the pipeline graph."""
@@ -125,20 +156,25 @@ class GraphRunner:
         """
         self._wire_graph(queue_max_size)

-        async with asyncio.TaskGroup() as tg:
-            # Create a task for the source feeder
-            tg.create_task(
-                self._run_source(
-                    source_stream, source_config["stage"], source_config["channel"]
-                )
-            )
-
-            # Create tasks for all pipeline stages
-            for stage in self._stages.values():
-                tg.create_task(stage.run())
-
-            # Yield results from the sink
-            async for result in self._run_sink(
-                sink_config["stage"], sink_config["channel"]
-            ):
-                yield result
+        try:
+            async with asyncio.TaskGroup() as tg:
+                # Create a task for the source feeder
+                tg.create_task(
+                    self._run_source(
+                        source_stream, source_config["stage"], source_config["channel"]
+                    )
+                )
+
+                # Create tasks for all pipeline stages
+                for stage in self._stages.values():
+                    tg.create_task(stage.run())
+
+                # Yield results from the sink
+                async for result in self._run_sink(
+                    sink_config["stage"], sink_config["channel"]
+                ):
+                    yield result
+        finally:
+            # Ensure thread pool cleanup on pipeline completion
+            # Note: We don't shutdown here as other pipelines might be using it
+            pass
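Worth noting: asyncio.TaskGroup, used above, requires Python 3.11 or newer, and the new try/finally deliberately leaves the shared pool alive because other pipelines may still hold it. A stripped-down sketch of the same shape, outside the diff and with illustrative names only (run_graph, producer, sink):

    import asyncio

    # Illustrative shape of the runner above: a TaskGroup owns the worker tasks
    # while the enclosing async generator yields whatever reaches the sink queue.
    async def run_graph(n_items: int):
        sink: asyncio.Queue = asyncio.Queue()
        stop = object()

        async def producer() -> None:
            for i in range(n_items):
                await sink.put(i)
            await sink.put(stop)

        try:
            async with asyncio.TaskGroup() as tg:
                tg.create_task(producer())
                while (item := await sink.get()) is not stop:
                    yield item
        finally:
            # As in the diff: no pool shutdown here, other pipelines may use it.
            pass

    async def main() -> None:
        async for item in run_graph(3):
            print(item)  # 0, 1, 2

    asyncio.run(main())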
@@ -1,11 +1,13 @@
 import asyncio
 import logging
+import uuid
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Set
+from typing import Dict, Optional

 from docling.datamodel.base_models import Page
 from docling.datamodel.document import ConversionResult
 from docling.pipeline.async_base_pipeline import DocumentTracker
+from docling.pipeline.graph import get_pipeline_thread_pool

 _log = logging.getLogger(__name__)

@@ -19,22 +21,29 @@ class AsyncPageTracker:
     keep_images: bool = False
     keep_backend: bool = False

+    def __post_init__(self):
+        """Initialize shared thread pool reference after dataclass creation"""
+        self._thread_pool = get_pipeline_thread_pool()
+
     async def register_document(
         self, conv_res: ConversionResult, total_pages: int
     ) -> str:
         """Register a new document for tracking"""
         async with self._lock:
-            doc_id = str(id(conv_res))
+            # Use UUID for better collision resistance than str(id())
+            doc_id = str(uuid.uuid4())
             self._doc_trackers[doc_id] = DocumentTracker(
                 doc_id=doc_id, total_pages=total_pages, conv_result=conv_res
             )
+            # Store the doc_id in the conv_res for later lookup
+            conv_res._async_doc_id = doc_id
             return doc_id

     async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None:
         """Track when a page backend is loaded"""
         async with self._lock:
-            doc_id = str(id(conv_res))
-            if doc_id in self._doc_trackers and page._backend is not None:
+            doc_id = getattr(conv_res, "_async_doc_id", None)
+            if doc_id and doc_id in self._doc_trackers and page._backend is not None:
                 self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend

     async def track_page_completion(
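The tracker now keys documents by a uuid4 stored on the ConversionResult rather than by str(id(conv_res)), since object ids can be recycled once an object is garbage-collected. A tiny sketch of the keying scheme, outside the diff, with a hypothetical Doc class standing in for ConversionResult:

    import uuid

    # Hypothetical stand-in for ConversionResult, to show the new keying scheme.
    class Doc:
        pass

    _trackers: dict[str, str] = {}

    def register(doc: Doc) -> str:
        # uuid4 keys stay unique even when Python recycles object ids.
        doc_id = str(uuid.uuid4())
        doc._async_doc_id = doc_id
        _trackers[doc_id] = "tracking"
        return doc_id

    def lookup(doc: Doc) -> str | None:
        # Same defensive lookup as the diff: tolerate unregistered documents.
        doc_id = getattr(doc, "_async_doc_id", None)
        if not doc_id or doc_id not in _trackers:
            return None
        return _trackers[doc_id]

    doc = Doc()
    register(doc)
    print(lookup(doc))    # "tracking"
    print(lookup(Doc()))  # None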
@@ -42,8 +51,8 @@ class AsyncPageTracker:
     ) -> bool:
         """Track page completion and cleanup when all pages done"""
         async with self._lock:
-            doc_id = str(id(conv_res))
-            if doc_id not in self._doc_trackers:
+            doc_id = getattr(conv_res, "_async_doc_id", None)
+            if not doc_id or doc_id not in self._doc_trackers:
                 _log.warning(f"Document {doc_id} not registered for tracking")
                 return False

@@ -58,6 +67,9 @@ class AsyncPageTracker:
            if tracker.processed_pages == tracker.total_pages:
                await self._cleanup_document_resources(tracker)
                del self._doc_trackers[doc_id]
+                # Clean up the doc_id from conv_res
+                if hasattr(conv_res, "_async_doc_id"):
+                    delattr(conv_res, "_async_doc_id")
                return True  # Document is complete

            return False  # Document is not yet complete
@@ -69,8 +81,10 @@ class AsyncPageTracker:
        for page_no, backend in tracker.page_backends.items():
            if backend is not None:
                try:
-                    # Run unload in thread to avoid blocking
-                    await asyncio.to_thread(backend.unload)
+                    # Run unload in shared thread pool to avoid blocking
+                    await asyncio.get_running_loop().run_in_executor(
+                        self._thread_pool, backend.unload
+                    )
                except Exception as e:
                    _log.warning(
                        f"Failed to unload backend for page {page_no}: {e}"
@@ -85,41 +99,3 @@ class AsyncPageTracker:
            for tracker in self._doc_trackers.values():
                await self._cleanup_document_resources(tracker)
            self._doc_trackers.clear()
-
-
-@dataclass
-class ConversionResultAccumulator:
-    """Accumulates updates to ConversionResult without immediate mutation"""
-
-    _updates: Dict[str, Dict] = field(default_factory=dict)  # doc_id -> updates
-    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
-
-    async def accumulate_page_result(
-        self, page_no: int, conv_res: ConversionResult, updates: Dict
-    ) -> None:
-        """Accumulate updates for later application"""
-        async with self._lock:
-            doc_id = str(id(conv_res))
-            if doc_id not in self._updates:
-                self._updates[doc_id] = {}
-
-            if page_no not in self._updates[doc_id]:
-                self._updates[doc_id][page_no] = {}
-
-            self._updates[doc_id][page_no].update(updates)
-
-    async def flush_to_conv_res(self, conv_res: ConversionResult) -> None:
-        """Apply all accumulated updates atomically"""
-        async with self._lock:
-            doc_id = str(id(conv_res))
-            if doc_id in self._updates:
-                # Apply updates
-                for page_no, updates in self._updates[doc_id].items():
-                    # Find the page and apply updates
-                    for page in conv_res.pages:
-                        if page.page_no == page_no:
-                            for key, value in updates.items():
-                                setattr(page, key, value)
-                            break
-
-                del self._updates[doc_id]
@@ -1,6 +1,7 @@
 from __future__ import annotations

+import asyncio
 import logging
 import time
 from collections import defaultdict
 from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
@@ -8,6 +9,8 @@ from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
 from docling.datamodel.document import ConversionResult, InputDocument, Page
 from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem

+_log = logging.getLogger(__name__)
+

 class SourceStage(PipelineStage):
     """A placeholder stage to represent the entry point of the graph."""
@@ -53,7 +56,7 @@ class ExtractionStage(PipelineStage):
            conv_res.pages.append(page)

            page._backend = await self.loop.run_in_executor(
-                None, conv_res.input._backend.load_page, page_no
+                self.thread_pool, conv_res.input._backend.load_page, page_no
            )

            if page._backend and page._backend.is_valid():
@@ -62,11 +65,18 @@ class ExtractionStage(PipelineStage):
                return StreamItem(
                    payload=page, conv_res_id=id(conv_res), conv_res=conv_res
                )
-        except Exception:
-            # In case of page-level error, we might log it but continue
-            # For now, we don't propagate failure here, but in the document level
-            pass
-        return None
+            else:
+                _log.warning(
+                    f"Failed to load or validate page {page_no} from document {conv_res.input.file.name}"
+                )
+                return None
+        except Exception as e:
+            _log.error(
+                f"Error extracting page {page_no} from document {conv_res.input.file.name}: {e}",
+                exc_info=True,
+            )
+            # Don't propagate individual page failures - document-level error handling will catch this
+            return None

    async def _process_document(self, in_doc: InputDocument):
        """Processes a single document, extracting all its pages."""
@@ -90,13 +100,34 @@ class ExtractionStage(PipelineStage):
                self.loop.create_task(self._extract_page(i, conv_res))
                for i in page_indices_to_extract
            ]
-            pages_extracted = await asyncio.gather(*tasks)
+            pages_extracted = await asyncio.gather(*tasks, return_exceptions=True)

-            await self._send_to_outputs(
-                self.output_channel, [p for p in pages_extracted if p]
-            )
+            # Filter out None results and exceptions, log any exceptions found
+            valid_pages = []
+            for i, result in enumerate(pages_extracted):
+                if isinstance(result, Exception):
+                    _log.error(
+                        f"Page extraction failed for page {page_indices_to_extract[i]} "
+                        f"in document {in_doc.file.name}: {result}"
+                    )
+                elif result is not None:
+                    valid_pages.append(result)
+
+            await self._send_to_outputs(self.output_channel, valid_pages)
+
+            # If no pages were successfully extracted, mark as failure
+            if not valid_pages:
+                _log.error(
+                    f"No pages could be extracted from document {in_doc.file.name}"
+                )
+                conv_res.status = "FAILURE"
+                await self._send_to_outputs(self.failure_channel, [conv_res])

-        except Exception:
+        except Exception as e:
+            _log.error(
+                f"Document-level extraction failed for {in_doc.file.name}: {e}",
+                exc_info=True,
+            )
            conv_res.status = "FAILURE"
            await self._send_to_outputs(self.failure_channel, [conv_res])

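With return_exceptions=True, one failing page no longer cancels its sibling extraction tasks; exceptions come back as values to be logged and filtered out. A compact sketch of that pattern, outside the diff, with a hypothetical extract coroutine:

    import asyncio

    # Hypothetical page extractor: fails on odd page numbers to show the pattern.
    async def extract(page_no: int) -> str:
        if page_no % 2:
            raise ValueError(f"bad page {page_no}")
        return f"page-{page_no}"

    async def main() -> None:
        results = await asyncio.gather(
            *(extract(i) for i in range(4)), return_exceptions=True
        )
        # Exceptions come back as values instead of cancelling the other tasks;
        # keep the successes and log (here: print) the failures.
        valid = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                print(f"page {i} failed: {result}")
            else:
                valid.append(result)
        print(valid)  # ['page-0', 'page-2']

    asyncio.run(main())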
@@ -130,7 +161,8 @@ class PageProcessorStage(PipelineStage):

            # The model call is sync, run in thread to avoid blocking event loop
            processed_page = await self.loop.run_in_executor(
-                None, lambda: next(iter(self.model(item.conv_res, [item.payload])))
+                self.thread_pool,
+                lambda: next(iter(self.model(item.conv_res, [item.payload]))),
            )
            item.payload = processed_page
            await self._send_to_outputs(self.output_channel, [item])
@@ -202,7 +234,7 @@ class BatchProcessorStage(PipelineStage):

            # The model call is sync, run in thread
            processed_pages = await self.loop.run_in_executor(
-                None, lambda: list(self.model(conv_res, pages))
+                self.thread_pool, lambda: list(self.model(conv_res, pages))
            )

            # Re-wrap the processed pages into StreamItems
@@ -237,7 +269,6 @@ class AggregationStage(PipelineStage):
     async def run(self) -> None:
         success_q = self.input_queues[self.success_channel]
         failure_q = self.input_queues.get(self.failure_channel)
-        doc_completers: Dict[int, asyncio.Future] = {}

         async def handle_successes():
             while True:
@@ -250,8 +281,6 @@ class AggregationStage(PipelineStage):
                if is_doc_complete:
                    await self.finalizer_func(item.conv_res)
                await self._send_to_outputs(self.output_channel, [item.conv_res])
-                if item.conv_res_id in doc_completers:
-                    doc_completers.pop(item.conv_res_id).set_result(True)

        async def handle_failures():
            if failure_q is None:
@@ -261,8 +290,6 @@ class AggregationStage(PipelineStage):
                if failed_res is STOP_SENTINEL:
                    break
                await self._send_to_outputs(self.output_channel, [failed_res])
-                if id(failed_res) in doc_completers:
-                    doc_completers.pop(id(failed_res)).set_result(True)

        # Create tasks only for channels that exist
        tasks = [handle_successes()]