UpstreamAwareQueue

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2025-07-15 20:09:05 +02:00
parent f56de726f3
commit ef25d03bc8
3 changed files with 233 additions and 1866 deletions

View File

@ -54,6 +54,38 @@ class QueueTerminator:
error: Optional[Exception] = None error: Optional[Exception] = None
class UpstreamAwareQueue(asyncio.Queue):
"""Queue that tracks upstream completion to optimize batch processing"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.upstream_complete = asyncio.Event()
self.items_in_flight = 0
async def put(self, item):
"""Track items and upstream completion signals"""
if isinstance(item, QueueTerminator):
self.upstream_complete.set()
else:
self.items_in_flight += 1
await super().put(item)
async def get(self):
"""Track when items are consumed"""
item = await super().get()
if not isinstance(item, QueueTerminator):
self.items_in_flight -= 1
return item
def should_stop_waiting_for_batch(self) -> bool:
"""Determine if we should stop waiting and process current batch"""
return (
self.upstream_complete.is_set()
and self.items_in_flight == 0
and self.empty()
)
class AsyncStandardPdfPipeline(AsyncPipeline): class AsyncStandardPdfPipeline(AsyncPipeline):
"""Async pipeline implementation with cross-document batching using structured concurrency""" """Async pipeline implementation with cross-document batching using structured concurrency"""
@ -187,9 +219,11 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
) -> AsyncIterable[ConversionResult]: ) -> AsyncIterable[ConversionResult]:
"""Main async processing with structured concurrency and proper exception handling""" """Main async processing with structured concurrency and proper exception handling"""
# Create queues for pipeline stages # Create queues for pipeline stages
page_queue = asyncio.Queue(maxsize=self.pipeline_options.extraction_queue_size) page_queue = UpstreamAwareQueue(
completed_queue = asyncio.Queue() maxsize=self.pipeline_options.extraction_queue_size
completed_docs = asyncio.Queue() )
completed_queue = UpstreamAwareQueue()
completed_docs = UpstreamAwareQueue()
# Track active documents for proper termination # Track active documents for proper termination
doc_tracker = {"active_docs": 0, "extraction_done": False} doc_tracker = {"active_docs": 0, "extraction_done": False}
@ -254,7 +288,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _extract_documents_wrapper( async def _extract_documents_wrapper(
self, self,
input_docs: AsyncIterable[InputDocument], input_docs: AsyncIterable[InputDocument],
page_queue: asyncio.Queue, page_queue: UpstreamAwareQueue,
track_document_start, track_document_start,
exception_event: asyncio.Event, exception_event: asyncio.Event,
doc_tracker: Dict[str, Any], doc_tracker: Dict[str, Any],
@ -279,8 +313,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _process_pages_wrapper( async def _process_pages_wrapper(
self, self,
page_queue: asyncio.Queue, page_queue: UpstreamAwareQueue,
completed_queue: asyncio.Queue, completed_queue: UpstreamAwareQueue,
exception_event: asyncio.Event, exception_event: asyncio.Event,
): ):
"""Wrapper for page processing with exception handling""" """Wrapper for page processing with exception handling"""
@ -295,8 +329,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _aggregate_results_wrapper( async def _aggregate_results_wrapper(
self, self,
completed_queue: asyncio.Queue, completed_queue: UpstreamAwareQueue,
completed_docs: asyncio.Queue, completed_docs: UpstreamAwareQueue,
track_document_complete, track_document_complete,
exception_event: asyncio.Event, exception_event: asyncio.Event,
): ):
@ -313,7 +347,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
raise raise
async def _yield_results( async def _yield_results(
self, completed_docs: asyncio.Queue, exception_event: asyncio.Event self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
): ):
"""Yield results as they complete""" """Yield results as they complete"""
while True: while True:
@ -334,7 +368,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _extract_documents_safe( async def _extract_documents_safe(
self, self,
input_docs: AsyncIterable[InputDocument], input_docs: AsyncIterable[InputDocument],
page_queue: asyncio.Queue, page_queue: UpstreamAwareQueue,
track_document_start, track_document_start,
exception_event: asyncio.Event, exception_event: asyncio.Event,
) -> None: ) -> None:
@ -401,17 +435,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _process_pages_safe( async def _process_pages_safe(
self, self,
page_queue: asyncio.Queue, page_queue: UpstreamAwareQueue,
completed_queue: asyncio.Queue, completed_queue: UpstreamAwareQueue,
exception_event: asyncio.Event, exception_event: asyncio.Event,
) -> None: ) -> None:
"""Process pages through model pipeline with proper termination""" """Process pages through model pipeline with proper termination"""
# Process batches through each model stage # Process batches through each model stage
preprocessing_queue = asyncio.Queue() preprocessing_queue = UpstreamAwareQueue()
ocr_queue = asyncio.Queue() ocr_queue = UpstreamAwareQueue()
layout_queue = asyncio.Queue() layout_queue = UpstreamAwareQueue()
table_queue = asyncio.Queue() table_queue = UpstreamAwareQueue()
assemble_queue = asyncio.Queue() assemble_queue = UpstreamAwareQueue()
# Start processing stages using TaskGroup # Start processing stages using TaskGroup
async with asyncio.TaskGroup() as tg: async with asyncio.TaskGroup() as tg:
@ -489,8 +523,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _batch_process_stage_safe( async def _batch_process_stage_safe(
self, self,
input_queue: asyncio.Queue, input_queue: UpstreamAwareQueue,
output_queue: asyncio.Queue, output_queue: UpstreamAwareQueue,
process_func, process_func,
batch_size: int, batch_size: int,
timeout: float, timeout: float,
@ -523,9 +557,19 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
batch.pages.append(item[0]) batch.pages.append(item[0])
batch.conv_results.append(item[1]) batch.conv_results.append(item[1])
# Try to fill batch up to batch_size # Try to fill batch up to batch_size or until upstream exhausted
while len(batch.pages) < batch_size: while len(batch.pages) < batch_size:
# Check if upstream is exhausted and we should process immediately
if input_queue.should_stop_waiting_for_batch():
break
# Calculate remaining time for batch timeout
remaining_time = timeout - (time.time() - batch.start_time) remaining_time = timeout - (time.time() - batch.start_time)
# If upstream is complete, use minimal timeout to drain quickly
if input_queue.upstream_complete.is_set():
remaining_time = min(remaining_time, 0.05)
if remaining_time <= 0: if remaining_time <= 0:
break break
@ -679,8 +723,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _finalize_pages_safe( async def _finalize_pages_safe(
self, self,
input_queue: asyncio.Queue, input_queue: UpstreamAwareQueue,
output_queue: asyncio.Queue, output_queue: UpstreamAwareQueue,
exception_event: asyncio.Event, exception_event: asyncio.Event,
) -> None: ) -> None:
"""Finalize pages and track completion with proper termination""" """Finalize pages and track completion with proper termination"""
@ -715,8 +759,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
async def _aggregate_results_safe( async def _aggregate_results_safe(
self, self,
completed_queue: asyncio.Queue, completed_queue: UpstreamAwareQueue,
completed_docs: asyncio.Queue, completed_docs: UpstreamAwareQueue,
track_document_complete, track_document_complete,
exception_event: asyncio.Event, exception_event: asyncio.Event,
) -> None: ) -> None:

View File

@ -41,7 +41,7 @@ authors = [
{ name = "Panos Vagenas", email = "pva@zurich.ibm.com" }, { name = "Panos Vagenas", email = "pva@zurich.ibm.com" },
{ name = "Peter Staar", email = "taa@zurich.ibm.com" }, { name = "Peter Staar", email = "taa@zurich.ibm.com" },
] ]
requires-python = '>=3.9,<4.0' requires-python = '>=3.11,<4.0'
dependencies = [ dependencies = [
'pydantic (>=2.0.0,<3.0.0)', 'pydantic (>=2.0.0,<3.0.0)',
'docling-core[chunking] (>=2.42.0,<3.0.0)', 'docling-core[chunking] (>=2.42.0,<3.0.0)',

2005
uv.lock generated

File diff suppressed because it is too large Load Diff