mirror of https://github.com/DS4SD/docling.git, synced 2025-07-24 19:14:23 +00:00

Refactoring into async pipeline primitives and graph

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

This commit is contained in:
parent ef25d03bc8
commit 0be9349884
@@ -1,11 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterable, Dict, List, Optional, Tuple
|
||||
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import ConversionStatus, Page
|
||||
from docling.datamodel.document import ConversionResult, InputDocument
|
||||
from docling.datamodel.pipeline_options import AsyncPdfPipelineOptions
|
||||
@@ -17,8 +13,6 @@ from docling.models.document_picture_classifier (
|
||||
DocumentPictureClassifierOptions,
|
||||
)
|
||||
from docling.models.factories import get_ocr_factory, get_picture_description_factory
|
||||
|
||||
# Import the same models used by StandardPdfPipeline
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
|
||||
from docling.models.page_preprocessing_model import (
|
||||
@@ -28,82 +22,36 @@ from docling.models.page_preprocessing_model (
|
||||
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.pipeline.async_base_pipeline import AsyncPipeline
|
||||
from docling.pipeline.resource_manager import (
|
||||
AsyncPageTracker,
|
||||
ConversionResultAccumulator,
|
||||
from docling.pipeline.graph import GraphRunner
|
||||
from docling.pipeline.resource_manager import AsyncPageTracker
|
||||
from docling.pipeline.stages import (
|
||||
AggregationStage,
|
||||
BatchProcessorStage,
|
||||
ExtractionStage,
|
||||
PageProcessorStage,
|
||||
SinkStage,
|
||||
SourceStage,
|
||||
)
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PageBatch:
|
||||
"""Represents a batch of pages to process through models"""
|
||||
|
||||
pages: List[Page] = field(default_factory=list)
|
||||
conv_results: List[ConversionResult] = field(default_factory=list)
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueTerminator:
|
||||
"""Sentinel value for proper queue termination tracking"""
|
||||
|
||||
stage: str
|
||||
error: Optional[Exception] = None
|
||||
|
||||
|
||||
class UpstreamAwareQueue(asyncio.Queue):
|
||||
"""Queue that tracks upstream completion to optimize batch processing"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.upstream_complete = asyncio.Event()
|
||||
self.items_in_flight = 0
|
||||
|
||||
async def put(self, item):
|
||||
"""Track items and upstream completion signals"""
|
||||
if isinstance(item, QueueTerminator):
|
||||
self.upstream_complete.set()
|
||||
else:
|
||||
self.items_in_flight += 1
|
||||
await super().put(item)
|
||||
|
||||
async def get(self):
|
||||
"""Track when items are consumed"""
|
||||
item = await super().get()
|
||||
if not isinstance(item, QueueTerminator):
|
||||
self.items_in_flight -= 1
|
||||
return item
|
||||
|
||||
def should_stop_waiting_for_batch(self) -> bool:
|
||||
"""Determine if we should stop waiting and process current batch"""
|
||||
return (
|
||||
self.upstream_complete.is_set()
|
||||
and self.items_in_flight == 0
|
||||
and self.empty()
|
||||
)
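A minimal sketch of the consumption pattern this queue is designed for, batching until either the batch is full or upstream is exhausted (drain_in_batches and process_batch are illustrative stand-ins, not part of this commit):
async def drain_in_batches(queue: UpstreamAwareQueue, batch_size: int, process_batch):
    # Collect up to batch_size items, flushing early once upstream has finished.
    batch = []
    while True:
        item = await queue.get()
        if isinstance(item, QueueTerminator):
            break
        batch.append(item)
        if len(batch) >= batch_size or queue.should_stop_waiting_for_batch():
            await process_batch(batch)
            batch = []
    if batch:
        await process_batch(batch)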
|
||||
|
||||
|
||||
class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
"""Async pipeline implementation with cross-document batching using structured concurrency"""
|
||||
"""
|
||||
An async, graph-based pipeline for processing PDFs with cross-document batching.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline_options: AsyncPdfPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.pipeline_options: AsyncPdfPipelineOptions = pipeline_options
|
||||
|
||||
# Resource management
|
||||
self.page_tracker = AsyncPageTracker(
|
||||
keep_images=self._should_keep_images(),
|
||||
keep_backend=self._should_keep_backend(),
|
||||
)
|
||||
|
||||
# Initialize models (same as StandardPdfPipeline)
|
||||
self._initialize_models()
|
||||
|
||||
def _should_keep_images(self) -> bool:
|
||||
"""Determine if images should be kept (same logic as StandardPdfPipeline)"""
|
||||
return (
|
||||
self.pipeline_options.generate_page_images
|
||||
or self.pipeline_options.generate_picture_images
|
||||
@@ -111,7 +59,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
)
|
||||
|
||||
def _should_keep_backend(self) -> bool:
|
||||
"""Determine if backend should be kept"""
|
||||
return (
|
||||
self.pipeline_options.do_formula_enrichment
|
||||
or self.pipeline_options.do_code_enrichment
|
||||
@@ -120,36 +67,26 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
)
|
||||
|
||||
def _initialize_models(self):
|
||||
"""Initialize all models (matching StandardPdfPipeline)"""
|
||||
artifacts_path = self._get_artifacts_path()
|
||||
|
||||
self.reading_order_model = ReadingOrderModel(options=ReadingOrderOptions())
|
||||
|
||||
# Build pipeline stages
|
||||
self.preprocessing_model = PagePreprocessingModel(
|
||||
options=PagePreprocessingOptions(
|
||||
images_scale=self.pipeline_options.images_scale,
|
||||
)
|
||||
)
|
||||
|
||||
self.ocr_model = self._get_ocr_model(artifacts_path)
|
||||
|
||||
self.layout_model = LayoutModel(
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
options=self.pipeline_options.layout_options,
|
||||
)
|
||||
|
||||
self.table_model = TableStructureModel(
|
||||
enabled=self.pipeline_options.do_table_structure,
|
||||
artifacts_path=artifacts_path,
|
||||
options=self.pipeline_options.table_structure_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
|
||||
|
||||
# Enrichment models
|
||||
self.code_formula_model = CodeFormulaModel(
|
||||
enabled=self.pipeline_options.do_code_enrichment
|
||||
or self.pipeline_options.do_formula_enrichment,
|
||||
@@ -160,20 +97,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
),
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
self.picture_classifier = DocumentPictureClassifier(
|
||||
enabled=self.pipeline_options.do_picture_classification,
|
||||
artifacts_path=artifacts_path,
|
||||
options=DocumentPictureClassifierOptions(),
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
self.picture_description_model = self._get_picture_description_model(
|
||||
artifacts_path
|
||||
)
|
||||
|
||||
def _get_artifacts_path(self) -> Optional[str]:
|
||||
"""Get artifacts path (same as StandardPdfPipeline)"""
|
||||
from pathlib import Path
|
||||
|
||||
artifacts_path = None
|
||||
@@ -190,7 +124,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
return artifacts_path
|
||||
|
||||
def _get_ocr_model(self, artifacts_path: Optional[str] = None) -> BaseOcrModel:
|
||||
"""Get OCR model (same as StandardPdfPipeline)"""
|
||||
factory = get_ocr_factory(
|
||||
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||
)
|
||||
@@ -202,7 +135,6 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
)
|
||||
|
||||
def _get_picture_description_model(self, artifacts_path: Optional[str] = None):
|
||||
"""Get picture description model (same as StandardPdfPipeline)"""
|
||||
factory = get_picture_description_factory(
|
||||
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||
)
|
||||
@@ -217,607 +149,116 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||
async def execute_stream(
|
||||
self, input_docs: AsyncIterable[InputDocument]
|
||||
) -> AsyncIterable[ConversionResult]:
|
||||
"""Main async processing with structured concurrency and proper exception handling"""
|
||||
# Create queues for pipeline stages
|
||||
page_queue = UpstreamAwareQueue(
|
||||
maxsize=self.pipeline_options.extraction_queue_size
|
||||
)
|
||||
completed_queue = UpstreamAwareQueue()
|
||||
completed_docs = UpstreamAwareQueue()
|
||||
"""Main async processing driven by a pipeline graph."""
|
||||
stages = [
|
||||
SourceStage("source"),
|
||||
ExtractionStage(
|
||||
"extractor",
|
||||
self.page_tracker,
|
||||
self.pipeline_options.max_concurrent_extractions,
|
||||
),
|
||||
PageProcessorStage("preprocessor", self.preprocessing_model),
|
||||
BatchProcessorStage(
|
||||
"ocr",
|
||||
self.ocr_model,
|
||||
self.pipeline_options.ocr_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
),
|
||||
BatchProcessorStage(
|
||||
"layout",
|
||||
self.layout_model,
|
||||
self.pipeline_options.layout_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
),
|
||||
BatchProcessorStage(
|
||||
"table",
|
||||
self.table_model,
|
||||
self.pipeline_options.table_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
),
|
||||
PageProcessorStage("assembler", self.assemble_model),
|
||||
AggregationStage("aggregator", self.page_tracker, self._finalize_document),
|
||||
SinkStage("sink"),
|
||||
]
|
||||
|
||||
# Track active documents for proper termination
|
||||
doc_tracker = {"active_docs": 0, "extraction_done": False}
|
||||
doc_lock = asyncio.Lock()
|
||||
edges = [
|
||||
# Main processing path
|
||||
{
|
||||
"from_stage": "source",
|
||||
"from_output": "out",
|
||||
"to_stage": "extractor",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "extractor",
|
||||
"from_output": "out",
|
||||
"to_stage": "preprocessor",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "preprocessor",
|
||||
"from_output": "out",
|
||||
"to_stage": "ocr",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "ocr",
|
||||
"from_output": "out",
|
||||
"to_stage": "layout",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "layout",
|
||||
"from_output": "out",
|
||||
"to_stage": "table",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "table",
|
||||
"from_output": "out",
|
||||
"to_stage": "assembler",
|
||||
"to_input": "in",
|
||||
},
|
||||
{
|
||||
"from_stage": "assembler",
|
||||
"from_output": "out",
|
||||
"to_stage": "aggregator",
|
||||
"to_input": "in",
|
||||
},
|
||||
# Failure path
|
||||
{
|
||||
"from_stage": "extractor",
|
||||
"from_output": "fail",
|
||||
"to_stage": "aggregator",
|
||||
"to_input": "fail",
|
||||
},
|
||||
# Final output
|
||||
{
|
||||
"from_stage": "aggregator",
|
||||
"from_output": "out",
|
||||
"to_stage": "sink",
|
||||
"to_input": "in",
|
||||
},
|
||||
]
|
||||
|
||||
# Create exception event for coordinated shutdown
|
||||
exception_event = asyncio.Event()
|
||||
|
||||
async def track_document_start():
|
||||
async with doc_lock:
|
||||
doc_tracker["active_docs"] += 1
|
||||
|
||||
async def track_document_complete():
|
||||
async with doc_lock:
|
||||
doc_tracker["active_docs"] -= 1
|
||||
if doc_tracker["extraction_done"] and doc_tracker["active_docs"] == 0:
|
||||
# All documents completed
|
||||
await completed_docs.put(None)
|
||||
runner = GraphRunner(stages, edges)
|
||||
source_config = {"stage": "source", "channel": "out"}
|
||||
sink_config = {"stage": "sink", "channel": "in"}
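For orientation, the edges above form a single linear chain with one failure branch; sketched here as a comment only, not part of the commit itself:
# source -> extractor -> preprocessor -> ocr -> layout -> table -> assembler -> aggregator -> sink
#              |                                                                    ^
#              +------------------------------ fail -------------------------------+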
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
# Start all tasks
|
||||
tg.create_task(
|
||||
self._extract_documents_wrapper(
|
||||
input_docs,
|
||||
page_queue,
|
||||
track_document_start,
|
||||
exception_event,
|
||||
doc_tracker,
|
||||
doc_lock,
|
||||
)
|
||||
)
|
||||
tg.create_task(
|
||||
self._process_pages_wrapper(
|
||||
page_queue, completed_queue, exception_event
|
||||
)
|
||||
)
|
||||
tg.create_task(
|
||||
self._aggregate_results_wrapper(
|
||||
completed_queue,
|
||||
completed_docs,
|
||||
track_document_complete,
|
||||
exception_event,
|
||||
)
|
||||
)
|
||||
|
||||
# Yield results as they complete
|
||||
async for result in self._yield_results(
|
||||
completed_docs, exception_event
|
||||
):
|
||||
yield result
|
||||
|
||||
async for result in runner.run(
|
||||
input_docs,
|
||||
source_config,
|
||||
sink_config,
|
||||
self.pipeline_options.extraction_queue_size,
|
||||
):
|
||||
yield result
|
||||
except* Exception as eg:
|
||||
# Handle exception group from TaskGroup
|
||||
_log.error(f"Pipeline failed with exceptions: {eg.exceptions}")
|
||||
# Re-raise the first exception
|
||||
raise (eg.exceptions[0] if eg.exceptions else RuntimeError("Unknown error"))
|
||||
finally:
|
||||
# Ensure cleanup
|
||||
await self.page_tracker.cleanup_all()
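A minimal, hypothetical driver for this method (assumes pipeline options and an async iterable of InputDocument objects are prepared elsewhere):
async def convert_all(
    pipeline: "AsyncStandardPdfPipeline", docs: AsyncIterable[InputDocument]
) -> List[ConversionResult]:
    # Results are yielded as soon as each document completes, independent of input order.
    results: List[ConversionResult] = []
    async for conv_res in pipeline.execute_stream(docs):
        results.append(conv_res)
    return results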
|
||||
|
||||
async def _extract_documents_wrapper(
|
||||
self,
|
||||
input_docs: AsyncIterable[InputDocument],
|
||||
page_queue: UpstreamAwareQueue,
|
||||
track_document_start,
|
||||
exception_event: asyncio.Event,
|
||||
doc_tracker: Dict[str, Any],
|
||||
doc_lock: asyncio.Lock,
|
||||
):
|
||||
"""Wrapper for document extraction with exception handling"""
|
||||
try:
|
||||
await self._extract_documents_safe(
|
||||
input_docs,
|
||||
page_queue,
|
||||
track_document_start,
|
||||
exception_event,
|
||||
)
|
||||
except Exception:
|
||||
exception_event.set()
|
||||
raise
|
||||
finally:
|
||||
async with doc_lock:
|
||||
doc_tracker["extraction_done"] = True
|
||||
# Send termination signal
|
||||
await page_queue.put(QueueTerminator("extraction"))
|
||||
|
||||
async def _process_pages_wrapper(
|
||||
self,
|
||||
page_queue: UpstreamAwareQueue,
|
||||
completed_queue: UpstreamAwareQueue,
|
||||
exception_event: asyncio.Event,
|
||||
):
|
||||
"""Wrapper for page processing with exception handling"""
|
||||
try:
|
||||
await self._process_pages_safe(page_queue, completed_queue, exception_event)
|
||||
except Exception:
|
||||
exception_event.set()
|
||||
raise
|
||||
finally:
|
||||
# Send termination signal
|
||||
await completed_queue.put(QueueTerminator("processing"))
|
||||
|
||||
async def _aggregate_results_wrapper(
|
||||
self,
|
||||
completed_queue: UpstreamAwareQueue,
|
||||
completed_docs: UpstreamAwareQueue,
|
||||
track_document_complete,
|
||||
exception_event: asyncio.Event,
|
||||
):
|
||||
"""Wrapper for result aggregation with exception handling"""
|
||||
try:
|
||||
await self._aggregate_results_safe(
|
||||
completed_queue,
|
||||
completed_docs,
|
||||
track_document_complete,
|
||||
exception_event,
|
||||
)
|
||||
except Exception:
|
||||
exception_event.set()
|
||||
raise
|
||||
|
||||
async def _yield_results(
|
||||
self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
|
||||
):
|
||||
"""Yield results as they complete"""
|
||||
while True:
|
||||
if exception_event.is_set():
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(completed_docs.get(), timeout=1.0)
|
||||
if result is None:
|
||||
break
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception:
|
||||
exception_event.set()
|
||||
raise
|
||||
|
||||
async def _extract_documents_safe(
|
||||
self,
|
||||
input_docs: AsyncIterable[InputDocument],
|
||||
page_queue: UpstreamAwareQueue,
|
||||
track_document_start,
|
||||
exception_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Extract pages from documents with exception handling"""
|
||||
async for in_doc in input_docs:
|
||||
if exception_event.is_set():
|
||||
break
|
||||
|
||||
await track_document_start()
|
||||
conv_res = ConversionResult(input=in_doc)
|
||||
|
||||
# Validate backend
|
||||
if not isinstance(conv_res.input._backend, PdfDocumentBackend):
|
||||
conv_res.status = ConversionStatus.FAILURE
|
||||
await page_queue.put((None, conv_res)) # Signal failed document
|
||||
continue
|
||||
|
||||
try:
|
||||
# Initialize document
|
||||
total_pages = conv_res.input.page_count
|
||||
await self.page_tracker.register_document(conv_res, total_pages)
|
||||
|
||||
# Extract pages with limited concurrency
|
||||
semaphore = asyncio.Semaphore(
|
||||
self.pipeline_options.max_concurrent_extractions
|
||||
)
|
||||
|
||||
async def extract_page(page_no: int):
|
||||
if exception_event.is_set():
|
||||
return
|
||||
|
||||
async with semaphore:
|
||||
# Create page
|
||||
page = Page(page_no=page_no)
|
||||
conv_res.pages.append(page)
|
||||
|
||||
# Initialize page backend
|
||||
page._backend = await asyncio.to_thread(
|
||||
conv_res.input._backend.load_page, page_no
|
||||
)
|
||||
|
||||
if page._backend is not None and page._backend.is_valid():
|
||||
page.size = page._backend.get_size()
|
||||
await self.page_tracker.track_page_loaded(page, conv_res)
|
||||
|
||||
# Send to processing queue
|
||||
await page_queue.put((page, conv_res))
|
||||
|
||||
# Extract all pages concurrently
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for i in range(total_pages):
|
||||
if exception_event.is_set():
|
||||
break
|
||||
start_page, end_page = conv_res.input.limits.page_range
|
||||
if (start_page - 1) <= i <= (end_page - 1):
|
||||
tg.create_task(extract_page(i))
|
||||
|
||||
except Exception as e:
|
||||
_log.error(f"Failed to extract document {in_doc.file.name}: {e}")
|
||||
conv_res.status = ConversionStatus.FAILURE
|
||||
# Signal document failure
|
||||
await page_queue.put((None, conv_res))
|
||||
raise
|
||||
|
||||
async def _process_pages_safe(
|
||||
self,
|
||||
page_queue: UpstreamAwareQueue,
|
||||
completed_queue: UpstreamAwareQueue,
|
||||
exception_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Process pages through model pipeline with proper termination"""
|
||||
# Process batches through each model stage
|
||||
preprocessing_queue = UpstreamAwareQueue()
|
||||
ocr_queue = UpstreamAwareQueue()
|
||||
layout_queue = UpstreamAwareQueue()
|
||||
table_queue = UpstreamAwareQueue()
|
||||
assemble_queue = UpstreamAwareQueue()
|
||||
|
||||
# Start processing stages using TaskGroup
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
# Preprocessing stage
|
||||
tg.create_task(
|
||||
self._batch_process_stage_safe(
|
||||
page_queue,
|
||||
preprocessing_queue,
|
||||
self._preprocess_batch,
|
||||
1,
|
||||
0, # No batching for preprocessing
|
||||
"preprocessing",
|
||||
exception_event,
|
||||
)
|
||||
)
|
||||
|
||||
# OCR stage
|
||||
tg.create_task(
|
||||
self._batch_process_stage_safe(
|
||||
preprocessing_queue,
|
||||
ocr_queue,
|
||||
self._ocr_batch,
|
||||
self.pipeline_options.ocr_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
"ocr",
|
||||
exception_event,
|
||||
)
|
||||
)
|
||||
|
||||
# Layout stage
|
||||
tg.create_task(
|
||||
self._batch_process_stage_safe(
|
||||
ocr_queue,
|
||||
layout_queue,
|
||||
self._layout_batch,
|
||||
self.pipeline_options.layout_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
"layout",
|
||||
exception_event,
|
||||
)
|
||||
)
|
||||
|
||||
# Table stage
|
||||
tg.create_task(
|
||||
self._batch_process_stage_safe(
|
||||
layout_queue,
|
||||
table_queue,
|
||||
self._table_batch,
|
||||
self.pipeline_options.table_batch_size,
|
||||
self.pipeline_options.batch_timeout_seconds,
|
||||
"table",
|
||||
exception_event,
|
||||
)
|
||||
)
|
||||
|
||||
# Assembly stage
|
||||
tg.create_task(
|
||||
self._batch_process_stage_safe(
|
||||
table_queue,
|
||||
assemble_queue,
|
||||
self._assemble_batch,
|
||||
1,
|
||||
0, # No batching for assembly
|
||||
"assembly",
|
||||
exception_event,
|
||||
)
|
||||
)
|
||||
|
||||
# Finalization stage
|
||||
tg.create_task(
|
||||
self._finalize_pages_safe(
|
||||
assemble_queue, completed_queue, exception_event
|
||||
)
|
||||
)
|
||||
|
||||
async def _batch_process_stage_safe(
|
||||
self,
|
||||
input_queue: UpstreamAwareQueue,
|
||||
output_queue: UpstreamAwareQueue,
|
||||
process_func,
|
||||
batch_size: int,
|
||||
timeout: float,
|
||||
stage_name: str,
|
||||
exception_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Generic batch processing stage with proper termination handling"""
|
||||
batch = PageBatch()
|
||||
|
||||
try:
|
||||
while not exception_event.is_set():
|
||||
# Collect batch
|
||||
try:
|
||||
# Get first item or wait for timeout
|
||||
if not batch.pages:
|
||||
item = await input_queue.get()
|
||||
|
||||
# Check for termination
|
||||
if isinstance(item, QueueTerminator):
|
||||
# Propagate termination signal
|
||||
await output_queue.put(item)
|
||||
break
|
||||
|
||||
# Handle failed document signal
|
||||
if item[0] is None:
|
||||
# Pass through failure signal
|
||||
await output_queue.put(item)
|
||||
continue
|
||||
|
||||
batch.pages.append(item[0])
|
||||
batch.conv_results.append(item[1])
|
||||
|
||||
# Try to fill batch up to batch_size or until upstream exhausted
|
||||
while len(batch.pages) < batch_size:
|
||||
# Check if upstream is exhausted and we should process immediately
|
||||
if input_queue.should_stop_waiting_for_batch():
|
||||
break
|
||||
|
||||
# Calculate remaining time for batch timeout
|
||||
remaining_time = timeout - (time.time() - batch.start_time)
|
||||
|
||||
# If upstream is complete, use minimal timeout to drain quickly
|
||||
if input_queue.upstream_complete.is_set():
|
||||
remaining_time = min(remaining_time, 0.05)
|
||||
|
||||
if remaining_time <= 0:
|
||||
break
|
||||
|
||||
try:
|
||||
item = await asyncio.wait_for(
|
||||
input_queue.get(), timeout=remaining_time
|
||||
)
|
||||
|
||||
# Check for termination
|
||||
if isinstance(item, QueueTerminator):
|
||||
# Put it back and process current batch
|
||||
await input_queue.put(item)
|
||||
break
|
||||
|
||||
# Handle failed document signal
|
||||
if item[0] is None:
|
||||
# Put it back and process current batch
|
||||
await input_queue.put(item)
|
||||
break
|
||||
|
||||
batch.pages.append(item[0])
|
||||
batch.conv_results.append(item[1])
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
# Process batch
|
||||
if batch.pages:
|
||||
processed = await process_func(batch)
|
||||
|
||||
# Send results to output queue
|
||||
for page, conv_res in processed:
|
||||
await output_queue.put((page, conv_res))
|
||||
|
||||
# Clear batch
|
||||
batch = PageBatch()
|
||||
|
||||
except Exception as e:
|
||||
_log.error(f"Error in {stage_name} batch processing: {e}")
|
||||
# Send failed items downstream
|
||||
for page, conv_res in zip(batch.pages, batch.conv_results):
|
||||
await output_queue.put((page, conv_res))
|
||||
batch = PageBatch()
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# Set exception event and propagate termination
|
||||
exception_event.set()
|
||||
await output_queue.put(QueueTerminator(stage_name, error=e))
|
||||
raise
|
||||
|
||||
async def _preprocess_batch(
|
||||
self, batch: PageBatch
|
||||
) -> List[Tuple[Page, ConversionResult]]:
|
||||
"""Preprocess pages (no actual batching needed)"""
|
||||
results = []
|
||||
for page, conv_res in zip(batch.pages, batch.conv_results):
|
||||
processed_page = await asyncio.to_thread(
|
||||
lambda: next(iter(self.preprocessing_model(conv_res, [page])))
|
||||
)
|
||||
results.append((processed_page, conv_res))
|
||||
return results
|
||||
|
||||
async def _ocr_batch(self, batch: PageBatch) -> List[Tuple[Page, ConversionResult]]:
|
||||
"""Process OCR in batch"""
|
||||
# Group by conversion result for proper context
|
||||
grouped = defaultdict(list)
|
||||
for page, conv_res in zip(batch.pages, batch.conv_results):
|
||||
grouped[id(conv_res)].append(page)
|
||||
|
||||
results = []
|
||||
for conv_res_id, pages in grouped.items():
|
||||
# Find the conv_res
|
||||
conv_res = next(
|
||||
cr
|
||||
for p, cr in zip(batch.pages, batch.conv_results)
|
||||
if id(cr) == conv_res_id
|
||||
)
|
||||
|
||||
# Process batch through OCR model
|
||||
processed_pages = await asyncio.to_thread(
|
||||
lambda: list(self.ocr_model(conv_res, pages))
|
||||
)
|
||||
|
||||
for page in processed_pages:
|
||||
results.append((page, conv_res))
|
||||
|
||||
return results
|
||||
|
||||
async def _layout_batch(
|
||||
self, batch: PageBatch
|
||||
) -> List[Tuple[Page, ConversionResult]]:
|
||||
"""Process layout in batch"""
|
||||
# Similar batching as OCR
|
||||
grouped = defaultdict(list)
|
||||
for page, conv_res in zip(batch.pages, batch.conv_results):
|
||||
grouped[id(conv_res)].append(page)
|
||||
|
||||
results = []
|
||||
for conv_res_id, pages in grouped.items():
|
||||
conv_res = next(
|
||||
cr
|
||||
for p, cr in zip(batch.pages, batch.conv_results)
|
||||
if id(cr) == conv_res_id
|
||||
)
|
||||
|
||||
processed_pages = await asyncio.to_thread(
|
||||
lambda: list(self.layout_model(conv_res, pages))
|
||||
)
|
||||
|
||||
for page in processed_pages:
|
||||
results.append((page, conv_res))
|
||||
|
||||
return results
|
||||
|
||||
async def _table_batch(
|
||||
self, batch: PageBatch
|
||||
) -> List[Tuple[Page, ConversionResult]]:
|
||||
"""Process tables in batch"""
|
||||
grouped = defaultdict(list)
|
||||
for page, conv_res in zip(batch.pages, batch.conv_results):
|
||||
grouped[id(conv_res)].append(page)
|
||||
|
||||
results = []
|
||||
for conv_res_id, pages in grouped.items():
|
||||
conv_res = next(
|
||||
cr
|
||||
for p, cr in zip(batch.pages, batch.conv_results)
|
||||
if id(cr) == conv_res_id
|
||||
)
|
||||
|
||||
processed_pages = await asyncio.to_thread(
|
||||
lambda: list(self.table_model(conv_res, pages))
|
||||
)
|
||||
|
||||
for page in processed_pages:
|
||||
results.append((page, conv_res))
|
||||
|
||||
return results
|
||||
|
||||
async def _assemble_batch(
|
||||
self, batch: PageBatch
|
||||
) -> List[Tuple[Page, ConversionResult]]:
|
||||
"""Assemble pages (no actual batching needed)"""
|
||||
results = []
|
||||
for page, conv_res in zip(batch.pages, batch.conv_results):
|
||||
assembled_page = await asyncio.to_thread(
|
||||
lambda: next(iter(self.assemble_model(conv_res, [page])))
|
||||
)
|
||||
results.append((assembled_page, conv_res))
|
||||
return results
|
||||
|
||||
async def _finalize_pages_safe(
|
||||
self,
|
||||
input_queue: UpstreamAwareQueue,
|
||||
output_queue: UpstreamAwareQueue,
|
||||
exception_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Finalize pages and track completion with proper termination"""
|
||||
try:
|
||||
while not exception_event.is_set():
|
||||
item = await input_queue.get()
|
||||
|
||||
# Check for termination
|
||||
if isinstance(item, QueueTerminator):
|
||||
# Propagate termination signal
|
||||
await output_queue.put(item)
|
||||
break
|
||||
|
||||
# Handle failed document signal
|
||||
if item[0] is None:
|
||||
# Pass through failure signal
|
||||
await output_queue.put(item)
|
||||
continue
|
||||
|
||||
page, conv_res = item
|
||||
|
||||
# Track page completion for resource cleanup
|
||||
await self.page_tracker.track_page_completion(page, conv_res)
|
||||
|
||||
# Send to output
|
||||
await output_queue.put((page, conv_res))
|
||||
|
||||
except Exception as e:
|
||||
exception_event.set()
|
||||
await output_queue.put(QueueTerminator("finalization", error=e))
|
||||
raise
|
||||
|
||||
async def _aggregate_results_safe(
|
||||
self,
|
||||
completed_queue: UpstreamAwareQueue,
|
||||
completed_docs: UpstreamAwareQueue,
|
||||
track_document_complete,
|
||||
exception_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Aggregate completed pages into documents with proper termination"""
|
||||
doc_pages = defaultdict(list)
|
||||
failed_docs = set()
|
||||
|
||||
try:
|
||||
while not exception_event.is_set():
|
||||
item = await completed_queue.get()
|
||||
|
||||
# Check for termination
|
||||
if isinstance(item, QueueTerminator):
|
||||
# Finalize any remaining documents
|
||||
for conv_res_id, pages in doc_pages.items():
|
||||
if conv_res_id not in failed_docs:
|
||||
# Find conv_res from first page
|
||||
conv_res = pages[0][1]
|
||||
await self._finalize_document(conv_res)
|
||||
await completed_docs.put(conv_res)
|
||||
await track_document_complete()
|
||||
break
|
||||
|
||||
# Handle failed document signal
|
||||
if item[0] is None:
|
||||
conv_res = item[1]
|
||||
doc_id = id(conv_res)
|
||||
failed_docs.add(doc_id)
|
||||
# Send failed document immediately
|
||||
await completed_docs.put(conv_res)
|
||||
await track_document_complete()
|
||||
continue
|
||||
|
||||
page, conv_res = item
|
||||
doc_id = id(conv_res)
|
||||
|
||||
if doc_id not in failed_docs:
|
||||
doc_pages[doc_id].append((page, conv_res))
|
||||
|
||||
# Check if document is complete
|
||||
if len(doc_pages[doc_id]) == len(conv_res.pages):
|
||||
await self._finalize_document(conv_res)
|
||||
await completed_docs.put(conv_res)
|
||||
await track_document_complete()
|
||||
del doc_pages[doc_id]
|
||||
|
||||
except Exception:
|
||||
exception_event.set()
|
||||
# Try to send any completed documents before failing
|
||||
for conv_res_id, pages in doc_pages.items():
|
||||
if conv_res_id not in failed_docs and pages:
|
||||
conv_res = pages[0][1]
|
||||
conv_res.status = ConversionStatus.PARTIAL_SUCCESS
|
||||
await completed_docs.put(conv_res)
|
||||
await track_document_complete()
|
||||
raise
|
||||
|
||||
async def _finalize_document(self, conv_res: ConversionResult) -> None:
|
||||
"""Finalize a complete document (same as StandardPdfPipeline._assemble_document)"""
|
||||
# This matches the logic from StandardPdfPipeline
|
||||
|
docling/pipeline/graph.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterable, Dict, List, Literal, Optional
|
||||
|
||||
# Sentinel to signal stream completion
|
||||
STOP_SENTINEL = object()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StreamItem:
|
||||
"""
|
||||
A wrapper for data flowing through the pipeline, maintaining a link
|
||||
to the original conversion result context.
|
||||
"""
|
||||
|
||||
payload: Any
|
||||
conv_res_id: int
|
||||
conv_res: Any # Opaque reference to ConversionResult
|
||||
|
||||
|
||||
class PipelineStage(ABC):
|
||||
"""A single, encapsulated step in a processing pipeline graph."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.input_queues: Dict[str, asyncio.Queue] = {}
|
||||
self.output_queues: Dict[str, List[asyncio.Queue]] = {}
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
@abstractmethod
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
The core execution logic for the stage. This method is responsible for
|
||||
consuming from input queues, processing data, and putting results into
|
||||
output queues.
|
||||
"""
|
||||
|
||||
async def _send_to_outputs(self, channel: str, items: List[StreamItem] | List[Any]):
|
||||
"""Helper to send processed items to all connected output queues."""
|
||||
if channel in self.output_queues:
|
||||
for queue in self.output_queues[channel]:
|
||||
for item in items:
|
||||
await queue.put(item)
|
||||
|
||||
async def _signal_downstream_completion(self):
|
||||
"""Signal that this stage is done processing to all output channels."""
|
||||
for channel_queues in self.output_queues.values():
|
||||
for queue in channel_queues:
|
||||
await queue.put(STOP_SENTINEL)
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
return self._loop
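A minimal sketch of a concrete stage built on this base class (the stage and its transform are illustrative, not part of this module):
class UppercaseStage(PipelineStage):
    """Reads string payloads from its "in" queue, uppercases them, forwards them to "out"."""

    async def run(self) -> None:
        q_in = self.input_queues["in"]
        while True:
            item = await q_in.get()
            if item is STOP_SENTINEL:
                # Tell every connected downstream queue that this stage is finished.
                await self._signal_downstream_completion()
                break
            await self._send_to_outputs("out", [item.upper()])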
|
||||
|
||||
|
||||
class GraphRunner:
|
||||
"""Connects stages and runs the pipeline graph."""
|
||||
|
||||
def __init__(self, stages: List[PipelineStage], edges: List[Dict[str, str]]):
|
||||
self._stages = {s.name: s for s in stages}
|
||||
self._edges = edges
|
||||
|
||||
def _wire_graph(self, queue_max_size: int):
|
||||
"""Create queues for edges and connect them to stage inputs and outputs."""
|
||||
for edge in self._edges:
|
||||
from_stage, from_output = edge["from_stage"], edge["from_output"]
|
||||
to_stage, to_input = edge["to_stage"], edge["to_input"]
|
||||
|
||||
queue = asyncio.Queue(maxsize=queue_max_size)
|
||||
|
||||
# Connect to source stage's output
|
||||
self._stages[from_stage].output_queues.setdefault(from_output, []).append(
|
||||
queue
|
||||
)
|
||||
|
||||
# Connect to destination stage's input
|
||||
self._stages[to_stage].input_queues[to_input] = queue
|
||||
|
||||
async def _run_source(
|
||||
self,
|
||||
source_stream: AsyncIterable[Any],
|
||||
source_stage: str,
|
||||
source_channel: str,
|
||||
):
|
||||
"""Feed the graph from an external async iterable."""
|
||||
output_queues = self._stages[source_stage].output_queues.get(source_channel, [])
|
||||
async for item in source_stream:
|
||||
for queue in output_queues:
|
||||
await queue.put(item)
|
||||
# Signal completion to all downstream queues
|
||||
for queue in output_queues:
|
||||
await queue.put(STOP_SENTINEL)
|
||||
|
||||
async def _run_sink(self, sink_stage: str, sink_channel: str) -> AsyncIterable[Any]:
|
||||
"""Yield results from the graph's final output queue."""
|
||||
queue = self._stages[sink_stage].input_queues[sink_channel]
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is STOP_SENTINEL:
|
||||
break
|
||||
yield item
|
||||
await queue.put(STOP_SENTINEL) # Allow other sinks to terminate
|
||||
|
||||
async def run(
|
||||
self,
|
||||
source_stream: AsyncIterable,
|
||||
source_config: Dict[str, str],
|
||||
sink_config: Dict[str, str],
|
||||
queue_max_size: int = 32,
|
||||
) -> AsyncIterable:
|
||||
"""
|
||||
Executes the entire pipeline graph.
|
||||
|
||||
Args:
|
||||
source_stream: The initial async iterable to feed the graph.
|
||||
source_config: Dictionary with "stage" and "channel" for the entry point.
|
||||
sink_config: Dictionary with "stage" and "channel" for the exit point.
|
||||
queue_max_size: The max size for the internal asyncio.Queues.
|
||||
"""
|
||||
self._wire_graph(queue_max_size)
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
# Create a task for the source feeder
|
||||
tg.create_task(
|
||||
self._run_source(
|
||||
source_stream, source_config["stage"], source_config["channel"]
|
||||
)
|
||||
)
|
||||
|
||||
# Create tasks for all pipeline stages
|
||||
for stage in self._stages.values():
|
||||
tg.create_task(stage.run())
|
||||
|
||||
# Yield results from the sink
|
||||
async for result in self._run_sink(
|
||||
sink_config["stage"], sink_config["channel"]
|
||||
):
|
||||
yield result
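A toy end-to-end wiring of GraphRunner, assuming the illustrative UppercaseStage sketched above (SourceStage and SinkStage come from docling.pipeline.stages):
import asyncio

from docling.pipeline.graph import GraphRunner
from docling.pipeline.stages import SinkStage, SourceStage


async def _demo() -> None:
    stages = [SourceStage("source"), UppercaseStage("upper"), SinkStage("sink")]
    edges = [
        {"from_stage": "source", "from_output": "out", "to_stage": "upper", "to_input": "in"},
        {"from_stage": "upper", "from_output": "out", "to_stage": "sink", "to_input": "in"},
    ]

    async def words():
        for word in ("graph", "runner"):
            yield word

    runner = GraphRunner(stages, edges)
    async for item in runner.run(
        words(),
        source_config={"stage": "source", "channel": "out"},
        sink_config={"stage": "sink", "channel": "in"},
    ):
        print(item)  # GRAPH, RUNNER


asyncio.run(_demo())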
|
docling/pipeline/resource_manager.py (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.pipeline.async_base_pipeline import DocumentTracker
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncPageTracker:
|
||||
"""Manages page backend lifecycle across documents"""
|
||||
|
||||
_doc_trackers: Dict[str, DocumentTracker] = field(default_factory=dict)
|
||||
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
keep_images: bool = False
|
||||
keep_backend: bool = False
|
||||
|
||||
async def register_document(
|
||||
self, conv_res: ConversionResult, total_pages: int
|
||||
) -> str:
|
||||
"""Register a new document for tracking"""
|
||||
async with self._lock:
|
||||
doc_id = str(id(conv_res))
|
||||
self._doc_trackers[doc_id] = DocumentTracker(
|
||||
doc_id=doc_id, total_pages=total_pages, conv_result=conv_res
|
||||
)
|
||||
return doc_id
|
||||
|
||||
async def track_page_loaded(self, page: Page, conv_res: ConversionResult) -> None:
|
||||
"""Track when a page backend is loaded"""
|
||||
async with self._lock:
|
||||
doc_id = str(id(conv_res))
|
||||
if doc_id in self._doc_trackers and page._backend is not None:
|
||||
self._doc_trackers[doc_id].page_backends[page.page_no] = page._backend
|
||||
|
||||
async def track_page_completion(
|
||||
self, page: Page, conv_res: ConversionResult
|
||||
) -> bool:
|
||||
"""Track page completion and cleanup when all pages done"""
|
||||
async with self._lock:
|
||||
doc_id = str(id(conv_res))
|
||||
if doc_id not in self._doc_trackers:
|
||||
_log.warning(f"Document {doc_id} not registered for tracking")
|
||||
return False
|
||||
|
||||
tracker = self._doc_trackers[doc_id]
|
||||
tracker.processed_pages += 1
|
||||
|
||||
# Clear this page's image cache if needed
|
||||
if not self.keep_images:
|
||||
page._image_cache = {}
|
||||
|
||||
# If all pages from this document are processed, cleanup
|
||||
if tracker.processed_pages == tracker.total_pages:
|
||||
await self._cleanup_document_resources(tracker)
|
||||
del self._doc_trackers[doc_id]
|
||||
return True # Document is complete
|
||||
|
||||
return False # Document is not yet complete
|
||||
|
||||
async def _cleanup_document_resources(self, tracker: DocumentTracker) -> None:
|
||||
"""Cleanup all resources for a completed document"""
|
||||
if not self.keep_backend:
|
||||
# Unload all page backends for this document
|
||||
for page_no, backend in tracker.page_backends.items():
|
||||
if backend is not None:
|
||||
try:
|
||||
# Run unload in thread to avoid blocking
|
||||
await asyncio.to_thread(backend.unload)
|
||||
except Exception as e:
|
||||
_log.warning(
|
||||
f"Failed to unload backend for page {page_no}: {e}"
|
||||
)
|
||||
|
||||
tracker.page_backends.clear()
|
||||
_log.debug(f"Cleaned up resources for document {tracker.doc_id}")
|
||||
|
||||
async def cleanup_all(self) -> None:
|
||||
"""Cleanup all tracked documents - for shutdown"""
|
||||
async with self._lock:
|
||||
for tracker in self._doc_trackers.values():
|
||||
await self._cleanup_document_resources(tracker)
|
||||
self._doc_trackers.clear()
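A rough sketch of the intended call sequence for one document (page construction and model processing elided; conv_res and its pages are assumed to exist already):
async def track_one_document(tracker: AsyncPageTracker, conv_res: ConversionResult) -> None:
    # Register up front, report each page backend as it is loaded,
    # then report completion; the tracker frees backends after the last page.
    await tracker.register_document(conv_res, total_pages=len(conv_res.pages))
    for page in conv_res.pages:
        await tracker.track_page_loaded(page, conv_res)
    for page in conv_res.pages:
        if await tracker.track_page_completion(page, conv_res):
            # True only once every registered page has been completed.
            break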
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversionResultAccumulator:
|
||||
"""Accumulates updates to ConversionResult without immediate mutation"""
|
||||
|
||||
_updates: Dict[str, Dict] = field(default_factory=dict) # doc_id -> updates
|
||||
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
|
||||
async def accumulate_page_result(
|
||||
self, page_no: int, conv_res: ConversionResult, updates: Dict
|
||||
) -> None:
|
||||
"""Accumulate updates for later application"""
|
||||
async with self._lock:
|
||||
doc_id = str(id(conv_res))
|
||||
if doc_id not in self._updates:
|
||||
self._updates[doc_id] = {}
|
||||
|
||||
if page_no not in self._updates[doc_id]:
|
||||
self._updates[doc_id][page_no] = {}
|
||||
|
||||
self._updates[doc_id][page_no].update(updates)
|
||||
|
||||
async def flush_to_conv_res(self, conv_res: ConversionResult) -> None:
|
||||
"""Apply all accumulated updates atomically"""
|
||||
async with self._lock:
|
||||
doc_id = str(id(conv_res))
|
||||
if doc_id in self._updates:
|
||||
# Apply updates
|
||||
for page_no, updates in self._updates[doc_id].items():
|
||||
# Find the page and apply updates
|
||||
for page in conv_res.pages:
|
||||
if page.page_no == page_no:
|
||||
for key, value in updates.items():
|
||||
setattr(page, key, value)
|
||||
break
|
||||
|
||||
del self._updates[doc_id]
|
docling/pipeline/stages.py (new file, 273 lines)
@@ -0,0 +1,273 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, List
|
||||
|
||||
from docling.datamodel.base_models import ConversionStatus
from docling.datamodel.document import ConversionResult, InputDocument, Page
|
||||
from docling.pipeline.graph import STOP_SENTINEL, PipelineStage, StreamItem
|
||||
|
||||
|
||||
class SourceStage(PipelineStage):
|
||||
"""A placeholder stage to represent the entry point of the graph."""
|
||||
|
||||
async def run(self) -> None:
|
||||
# This stage is driven by the GraphRunner's _run_source method
|
||||
# and does not have its own execution loop.
|
||||
pass
|
||||
|
||||
|
||||
class SinkStage(PipelineStage):
|
||||
"""A placeholder stage to represent the exit point of the graph."""
|
||||
|
||||
async def run(self) -> None:
|
||||
# This stage is read by the GraphRunner's _run_sink method
|
||||
# and does not have its own execution loop.
|
||||
pass
|
||||
|
||||
|
||||
class ExtractionStage(PipelineStage):
|
||||
"""Extracts pages from documents and tracks them."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
page_tracker: Any,
|
||||
max_concurrent_extractions: int,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.page_tracker = page_tracker
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent_extractions)
|
||||
self.input_channel = "in"
|
||||
self.output_channel = "out"
|
||||
self.failure_channel = "fail"
|
||||
|
||||
async def _extract_page(
|
||||
self, page_no: int, conv_res: ConversionResult
|
||||
) -> StreamItem | None:
|
||||
"""Coroutine to extract a single page."""
|
||||
try:
|
||||
async with self.semaphore:
|
||||
page = Page(page_no=page_no)
|
||||
conv_res.pages.append(page)
|
||||
|
||||
page._backend = await self.loop.run_in_executor(
|
||||
None, conv_res.input._backend.load_page, page_no
|
||||
)
|
||||
|
||||
if page._backend and page._backend.is_valid():
|
||||
page.size = page._backend.get_size()
|
||||
await self.page_tracker.track_page_loaded(page, conv_res)
|
||||
return StreamItem(
|
||||
payload=page, conv_res_id=id(conv_res), conv_res=conv_res
|
||||
)
|
||||
except Exception:
|
||||
# In case of page-level error, we might log it but continue
|
||||
# For now, we don't propagate failure here, but in the document level
|
||||
pass
|
||||
return None
|
||||
|
||||
async def _process_document(self, in_doc: InputDocument):
|
||||
"""Processes a single document, extracting all its pages."""
|
||||
conv_res = ConversionResult(input=in_doc)
|
||||
|
||||
try:
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
|
||||
if not isinstance(in_doc._backend, PdfDocumentBackend):
|
||||
raise TypeError("Backend is not a valid PdfDocumentBackend")
|
||||
|
||||
total_pages = in_doc.page_count
|
||||
await self.page_tracker.register_document(conv_res, total_pages)
|
||||
|
||||
start_page, end_page = conv_res.input.limits.page_range
|
||||
page_indices_to_extract = [
|
||||
i for i in range(total_pages) if (start_page - 1) <= i <= (end_page - 1)
|
||||
]
|
||||
|
||||
tasks = [
|
||||
self.loop.create_task(self._extract_page(i, conv_res))
|
||||
for i in page_indices_to_extract
|
||||
]
|
||||
pages_extracted = await asyncio.gather(*tasks)
|
||||
|
||||
await self._send_to_outputs(
|
||||
self.output_channel, [p for p in pages_extracted if p]
|
||||
)
|
||||
|
||||
except Exception:
|
||||
conv_res.status = ConversionStatus.FAILURE
|
||||
await self._send_to_outputs(self.failure_channel, [conv_res])
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Main loop to consume documents and launch extraction tasks."""
|
||||
q_in = self.input_queues[self.input_channel]
|
||||
while True:
|
||||
doc = await q_in.get()
|
||||
if doc is STOP_SENTINEL:
|
||||
await self._signal_downstream_completion()
|
||||
break
|
||||
await self._process_document(doc)
|
||||
|
||||
|
||||
class PageProcessorStage(PipelineStage):
|
||||
"""Applies a synchronous, 1-to-1 processing function to each page."""
|
||||
|
||||
def __init__(self, name: str, model: Any):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
self.input_channel = "in"
|
||||
self.output_channel = "out"
|
||||
|
||||
async def run(self) -> None:
|
||||
q_in = self.input_queues[self.input_channel]
|
||||
while True:
|
||||
item = await q_in.get()
|
||||
if item is STOP_SENTINEL:
|
||||
await self._signal_downstream_completion()
|
||||
break
|
||||
|
||||
# The model call is sync, run in thread to avoid blocking event loop
|
||||
processed_page = await self.loop.run_in_executor(
|
||||
None, lambda: next(iter(self.model(item.conv_res, [item.payload])))
|
||||
)
|
||||
item.payload = processed_page
|
||||
await self._send_to_outputs(self.output_channel, [item])
|
||||
|
||||
|
||||
class BatchProcessorStage(PipelineStage):
|
||||
"""Batches items and applies a synchronous model to the batch."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model: Any,
|
||||
batch_size: int,
|
||||
batch_timeout: float,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.batch_timeout = batch_timeout
|
||||
self.input_channel = "in"
|
||||
self.output_channel = "out"
|
||||
|
||||
async def _collect_batch(self, q_in: asyncio.Queue) -> List[StreamItem] | None:
|
||||
"""Collects a batch of items from the input queue with a timeout."""
|
||||
try:
|
||||
# Wait for the first item without a timeout
|
||||
first_item = await q_in.get()
|
||||
if first_item is STOP_SENTINEL:
|
||||
return None # End of stream
|
||||
except asyncio.CancelledError:
|
||||
return None
|
||||
|
||||
batch = [first_item]
|
||||
start_time = self.loop.time()
|
||||
|
||||
while len(batch) < self.batch_size:
|
||||
timeout = self.batch_timeout - (self.loop.time() - start_time)
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
item = await asyncio.wait_for(q_in.get(), timeout)
|
||||
if item is STOP_SENTINEL:
|
||||
# Put sentinel back for other potential consumers or the main loop
|
||||
await q_in.put(STOP_SENTINEL)
|
||||
break
|
||||
batch.append(item)
|
||||
except asyncio.TimeoutError:
|
||||
break # Batching timeout reached
|
||||
return batch
|
||||
|
||||
async def run(self) -> None:
|
||||
q_in = self.input_queues[self.input_channel]
|
||||
while True:
|
||||
batch = await self._collect_batch(q_in)
|
||||
|
||||
if not batch: # This can be None or an empty list.
|
||||
await self._signal_downstream_completion()
|
||||
break
|
||||
|
||||
# Group pages by their original ConversionResult
|
||||
grouped_by_doc = defaultdict(list)
|
||||
for item in batch:
|
||||
grouped_by_doc[item.conv_res_id].append(item)
|
||||
|
||||
processed_items = []
|
||||
for conv_res_id, items in grouped_by_doc.items():
|
||||
conv_res = items[0].conv_res
|
||||
pages = [item.payload for item in items]
|
||||
|
||||
# The model call is sync, run in thread
|
||||
processed_pages = await self.loop.run_in_executor(
|
||||
None, lambda: list(self.model(conv_res, pages))
|
||||
)
|
||||
|
||||
# Re-wrap the processed pages into StreamItems
|
||||
for i, page in enumerate(processed_pages):
|
||||
processed_items.append(
|
||||
StreamItem(
|
||||
payload=page,
|
||||
conv_res_id=items[i].conv_res_id,
|
||||
conv_res=items[i].conv_res,
|
||||
)
|
||||
)
|
||||
|
||||
await self._send_to_outputs(self.output_channel, processed_items)
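The stage only assumes that model is a callable of the form model(conv_res, pages) returning the processed pages in input order, matching the existing docling page-model interface; a hypothetical stand-in that satisfies this contract:
class NoOpPageModel:
    """Illustrative model: yields pages unchanged and in order."""

    def __call__(self, conv_res, pages):
        # Real models (OCR, layout, table structure) return enriched Page objects here.
        yield from pages


# e.g. BatchProcessorStage("noop", NoOpPageModel(), batch_size=4, batch_timeout=0.5)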
|
||||
|
||||
|
||||
class AggregationStage(PipelineStage):
|
||||
"""Aggregates processed pages back into completed documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
page_tracker: Any,
|
||||
finalizer_func: Callable[[ConversionResult], Coroutine],
|
||||
):
|
||||
super().__init__(name)
|
||||
self.page_tracker = page_tracker
|
||||
self.finalizer_func = finalizer_func
|
||||
self.success_channel = "in"
|
||||
self.failure_channel = "fail"
|
||||
self.output_channel = "out"
|
||||
|
||||
async def run(self) -> None:
|
||||
success_q = self.input_queues[self.success_channel]
|
||||
failure_q = self.input_queues.get(self.failure_channel)
|
||||
doc_completers: Dict[int, asyncio.Future] = {}
|
||||
|
||||
async def handle_successes():
|
||||
while True:
|
||||
item = await success_q.get()
|
||||
if item is STOP_SENTINEL:
|
||||
break
|
||||
is_doc_complete = await self.page_tracker.track_page_completion(
|
||||
item.payload, item.conv_res
|
||||
)
|
||||
if is_doc_complete:
|
||||
await self.finalizer_func(item.conv_res)
|
||||
await self._send_to_outputs(self.output_channel, [item.conv_res])
|
||||
if item.conv_res_id in doc_completers:
|
||||
doc_completers.pop(item.conv_res_id).set_result(True)
|
||||
|
||||
async def handle_failures():
|
||||
if failure_q is None:
|
||||
return # No failure channel, nothing to do
|
||||
while True:
|
||||
failed_res = await failure_q.get()
|
||||
if failed_res is STOP_SENTINEL:
|
||||
break
|
||||
await self._send_to_outputs(self.output_channel, [failed_res])
|
||||
if id(failed_res) in doc_completers:
|
||||
doc_completers.pop(id(failed_res)).set_result(True)
|
||||
|
||||
# Create tasks only for channels that exist
|
||||
tasks = [handle_successes()]
|
||||
if failure_q is not None:
|
||||
tasks.append(handle_failures())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
await self._signal_downstream_completion()
|