UpstreamAwareQueue

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

Author: Christoph Auer
Date:   2025-07-15 20:09:05 +02:00
parent: f56de726f3
commit: ef25d03bc8

3 changed files with 233 additions and 1866 deletions


@@ -54,6 +54,38 @@ class QueueTerminator:
     error: Optional[Exception] = None
 
 
+class UpstreamAwareQueue(asyncio.Queue):
+    """Queue that tracks upstream completion to optimize batch processing"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.upstream_complete = asyncio.Event()
+        self.items_in_flight = 0
+
+    async def put(self, item):
+        """Track items and upstream completion signals"""
+        if isinstance(item, QueueTerminator):
+            self.upstream_complete.set()
+        else:
+            self.items_in_flight += 1
+        await super().put(item)
+
+    async def get(self):
+        """Track when items are consumed"""
+        item = await super().get()
+        if not isinstance(item, QueueTerminator):
+            self.items_in_flight -= 1
+        return item
+
+    def should_stop_waiting_for_batch(self) -> bool:
+        """Determine if we should stop waiting and process current batch"""
+        return (
+            self.upstream_complete.is_set()
+            and self.items_in_flight == 0
+            and self.empty()
+        )
+
+
 class AsyncStandardPdfPipeline(AsyncPipeline):
     """Async pipeline implementation with cross-document batching using structured concurrency"""
@@ -187,9 +219,11 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
     ) -> AsyncIterable[ConversionResult]:
         """Main async processing with structured concurrency and proper exception handling"""
         # Create queues for pipeline stages
-        page_queue = asyncio.Queue(maxsize=self.pipeline_options.extraction_queue_size)
-        completed_queue = asyncio.Queue()
-        completed_docs = asyncio.Queue()
+        page_queue = UpstreamAwareQueue(
+            maxsize=self.pipeline_options.extraction_queue_size
+        )
+        completed_queue = UpstreamAwareQueue()
+        completed_docs = UpstreamAwareQueue()
 
         # Track active documents for proper termination
         doc_tracker = {"active_docs": 0, "extraction_done": False}
@@ -254,7 +288,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
     async def _extract_documents_wrapper(
         self,
         input_docs: AsyncIterable[InputDocument],
-        page_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
         track_document_start,
         exception_event: asyncio.Event,
         doc_tracker: Dict[str, Any],
@@ -279,8 +313,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
 
     async def _process_pages_wrapper(
         self,
-        page_queue: asyncio.Queue,
-        completed_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
+        completed_queue: UpstreamAwareQueue,
         exception_event: asyncio.Event,
     ):
         """Wrapper for page processing with exception handling"""
@@ -295,8 +329,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
 
     async def _aggregate_results_wrapper(
         self,
-        completed_queue: asyncio.Queue,
-        completed_docs: asyncio.Queue,
+        completed_queue: UpstreamAwareQueue,
+        completed_docs: UpstreamAwareQueue,
         track_document_complete,
         exception_event: asyncio.Event,
     ):
@@ -313,7 +347,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
             raise
 
     async def _yield_results(
-        self, completed_docs: asyncio.Queue, exception_event: asyncio.Event
+        self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
     ):
         """Yield results as they complete"""
         while True:
@@ -334,7 +368,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
     async def _extract_documents_safe(
         self,
         input_docs: AsyncIterable[InputDocument],
-        page_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
         track_document_start,
         exception_event: asyncio.Event,
     ) -> None:
@@ -401,17 +435,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
 
     async def _process_pages_safe(
         self,
-        page_queue: asyncio.Queue,
-        completed_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
+        completed_queue: UpstreamAwareQueue,
         exception_event: asyncio.Event,
     ) -> None:
         """Process pages through model pipeline with proper termination"""
         # Process batches through each model stage
-        preprocessing_queue = asyncio.Queue()
-        ocr_queue = asyncio.Queue()
-        layout_queue = asyncio.Queue()
-        table_queue = asyncio.Queue()
-        assemble_queue = asyncio.Queue()
+        preprocessing_queue = UpstreamAwareQueue()
+        ocr_queue = UpstreamAwareQueue()
+        layout_queue = UpstreamAwareQueue()
+        table_queue = UpstreamAwareQueue()
+        assemble_queue = UpstreamAwareQueue()
 
         # Start processing stages using TaskGroup
         async with asyncio.TaskGroup() as tg:
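Note that asyncio.TaskGroup, used here to supervise the stage tasks, is only available from Python 3.11 onward, which is what motivates the requires-python bump in pyproject.toml below. A self-contained sketch of the pattern (stage names and delays are placeholders, not the pipeline's real stages):

import asyncio

async def _stage(name: str, delay: float) -> None:
    # Stand-in for a real pipeline stage coroutine.
    await asyncio.sleep(delay)
    print(f"{name} done")

async def main() -> None:
    # If any task in the group raises, the TaskGroup cancels its siblings
    # and re-raises; on success it waits for all tasks to finish.
    async with asyncio.TaskGroup() as tg:
        tg.create_task(_stage("preprocess", 0.1))
        tg.create_task(_stage("ocr", 0.2))

asyncio.run(main())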
@@ -489,8 +523,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
 
     async def _batch_process_stage_safe(
         self,
-        input_queue: asyncio.Queue,
-        output_queue: asyncio.Queue,
+        input_queue: UpstreamAwareQueue,
+        output_queue: UpstreamAwareQueue,
         process_func,
         batch_size: int,
         timeout: float,
@@ -523,9 +557,19 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
                     batch.pages.append(item[0])
                     batch.conv_results.append(item[1])
 
-                # Try to fill batch up to batch_size
+                # Try to fill batch up to batch_size or until upstream exhausted
                 while len(batch.pages) < batch_size:
+                    # Check if upstream is exhausted and we should process immediately
+                    if input_queue.should_stop_waiting_for_batch():
+                        break
+
                     # Calculate remaining time for batch timeout
                     remaining_time = timeout - (time.time() - batch.start_time)
+
+                    # If upstream is complete, use minimal timeout to drain quickly
+                    if input_queue.upstream_complete.is_set():
+                        remaining_time = min(remaining_time, 0.05)
+
                     if remaining_time <= 0:
                         break
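Read in isolation, the batch-filling loop amounts to the following standalone helper (fill_batch, started_at, and the plain-list batch are illustrative simplifications of the code above, which appends to batch.pages; terminator handling is omitted):

import asyncio
import time

async def fill_batch(queue, batch, batch_size, timeout, started_at):
    # Keep pulling items until the batch is full, the timeout budget is
    # spent, or upstream is finished and the queue has drained.
    while len(batch) < batch_size:
        if queue.should_stop_waiting_for_batch():
            break  # upstream done and nothing left: ship the partial batch
        remaining = timeout - (time.time() - started_at)
        if queue.upstream_complete.is_set():
            # Upstream finished, but items may still be in flight: poll briefly.
            remaining = min(remaining, 0.05)
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), remaining))
        except asyncio.TimeoutError:
            break
    return batch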
@@ -679,8 +723,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
 
     async def _finalize_pages_safe(
         self,
-        input_queue: asyncio.Queue,
-        output_queue: asyncio.Queue,
+        input_queue: UpstreamAwareQueue,
+        output_queue: UpstreamAwareQueue,
         exception_event: asyncio.Event,
     ) -> None:
         """Finalize pages and track completion with proper termination"""
@@ -715,8 +759,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
 
     async def _aggregate_results_safe(
         self,
-        completed_queue: asyncio.Queue,
-        completed_docs: asyncio.Queue,
+        completed_queue: UpstreamAwareQueue,
+        completed_docs: UpstreamAwareQueue,
         track_document_complete,
         exception_event: asyncio.Event,
     ) -> None:

pyproject.toml

@@ -41,7 +41,7 @@ authors = [
     { name = "Panos Vagenas", email = "pva@zurich.ibm.com" },
     { name = "Peter Staar", email = "taa@zurich.ibm.com" },
 ]
-requires-python = '>=3.9,<4.0'
+requires-python = '>=3.11,<4.0'
 dependencies = [
     'pydantic (>=2.0.0,<3.0.0)',
     'docling-core[chunking] (>=2.42.0,<3.0.0)',

uv.lock (generated)

File diff suppressed because it is too large.