Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-24 19:14:23 +00:00)

Commit ef25d03bc8: UpstreamAwareQueue
Parent: f56de726f3
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
@@ -54,6 +54,38 @@ class QueueTerminator:
     error: Optional[Exception] = None


+class UpstreamAwareQueue(asyncio.Queue):
+    """Queue that tracks upstream completion to optimize batch processing"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.upstream_complete = asyncio.Event()
+        self.items_in_flight = 0
+
+    async def put(self, item):
+        """Track items and upstream completion signals"""
+        if isinstance(item, QueueTerminator):
+            self.upstream_complete.set()
+        else:
+            self.items_in_flight += 1
+        await super().put(item)
+
+    async def get(self):
+        """Track when items are consumed"""
+        item = await super().get()
+        if not isinstance(item, QueueTerminator):
+            self.items_in_flight -= 1
+        return item
+
+    def should_stop_waiting_for_batch(self) -> bool:
+        """Determine if we should stop waiting and process current batch"""
+        return (
+            self.upstream_complete.is_set()
+            and self.items_in_flight == 0
+            and self.empty()
+        )
+
+
 class AsyncStandardPdfPipeline(AsyncPipeline):
     """Async pipeline implementation with cross-document batching using structured concurrency"""

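Note: a minimal sketch (not part of this commit) of the drain semantics above, assuming the UpstreamAwareQueue class from this hunk is in scope. The QueueTerminator stub below stands in for the sentinel class defined earlier in the file; only its error field is taken from the diff.

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class QueueTerminator:
    """Stub standing in for the sentinel class defined earlier in this file."""
    error: Optional[Exception] = None


async def demo() -> None:
    q = UpstreamAwareQueue()  # the class from the hunk above

    await q.put("page-1")
    await q.put("page-2")
    await q.put(QueueTerminator())  # sentinel: sets upstream_complete

    assert q.upstream_complete.is_set()
    assert not q.should_stop_waiting_for_batch()  # two items still in flight

    await q.get()  # page-1 (items_in_flight -> 1)
    await q.get()  # page-2 (items_in_flight -> 0)
    await q.get()  # the terminator itself (no decrement)

    # Upstream done, nothing in flight, queue empty: stop batching.
    assert q.should_stop_waiting_for_batch()


asyncio.run(demo())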
@@ -187,9 +219,11 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
     ) -> AsyncIterable[ConversionResult]:
         """Main async processing with structured concurrency and proper exception handling"""
         # Create queues for pipeline stages
-        page_queue = asyncio.Queue(maxsize=self.pipeline_options.extraction_queue_size)
-        completed_queue = asyncio.Queue()
-        completed_docs = asyncio.Queue()
+        page_queue = UpstreamAwareQueue(
+            maxsize=self.pipeline_options.extraction_queue_size
+        )
+        completed_queue = UpstreamAwareQueue()
+        completed_docs = UpstreamAwareQueue()

         # Track active documents for proper termination
         doc_tracker = {"active_docs": 0, "extraction_done": False}
@@ -254,7 +288,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
     async def _extract_documents_wrapper(
         self,
         input_docs: AsyncIterable[InputDocument],
-        page_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
         track_document_start,
         exception_event: asyncio.Event,
         doc_tracker: Dict[str, Any],
@@ -279,8 +313,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):

     async def _process_pages_wrapper(
         self,
-        page_queue: asyncio.Queue,
-        completed_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
+        completed_queue: UpstreamAwareQueue,
         exception_event: asyncio.Event,
     ):
         """Wrapper for page processing with exception handling"""
@@ -295,8 +329,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):

     async def _aggregate_results_wrapper(
         self,
-        completed_queue: asyncio.Queue,
-        completed_docs: asyncio.Queue,
+        completed_queue: UpstreamAwareQueue,
+        completed_docs: UpstreamAwareQueue,
         track_document_complete,
         exception_event: asyncio.Event,
     ):
@@ -313,7 +347,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
             raise

     async def _yield_results(
-        self, completed_docs: asyncio.Queue, exception_event: asyncio.Event
+        self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
     ):
         """Yield results as they complete"""
         while True:
@@ -334,7 +368,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
     async def _extract_documents_safe(
         self,
         input_docs: AsyncIterable[InputDocument],
-        page_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
         track_document_start,
         exception_event: asyncio.Event,
     ) -> None:
@@ -401,17 +435,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):

     async def _process_pages_safe(
         self,
-        page_queue: asyncio.Queue,
-        completed_queue: asyncio.Queue,
+        page_queue: UpstreamAwareQueue,
+        completed_queue: UpstreamAwareQueue,
         exception_event: asyncio.Event,
     ) -> None:
         """Process pages through model pipeline with proper termination"""
         # Process batches through each model stage
-        preprocessing_queue = asyncio.Queue()
-        ocr_queue = asyncio.Queue()
-        layout_queue = asyncio.Queue()
-        table_queue = asyncio.Queue()
-        assemble_queue = asyncio.Queue()
+        preprocessing_queue = UpstreamAwareQueue()
+        ocr_queue = UpstreamAwareQueue()
+        layout_queue = UpstreamAwareQueue()
+        table_queue = UpstreamAwareQueue()
+        assemble_queue = UpstreamAwareQueue()

         # Start processing stages using TaskGroup
         async with asyncio.TaskGroup() as tg:
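Note: a reduced sketch of the stage-wiring pattern used here (hypothetical stage names and a None sentinel, not this commit's code): each stage drains its input queue into the next, and asyncio.TaskGroup cancels the sibling tasks if any stage raises.

import asyncio


async def stage(name: str, src: asyncio.Queue, dst: asyncio.Queue) -> None:
    """Forward items src -> dst until a None sentinel arrives."""
    while True:
        item = await src.get()
        if item is None:  # sentinel: propagate downstream and stop
            await dst.put(None)
            return
        await dst.put(f"{name}({item})")


async def main() -> None:
    q1, q2, q3 = asyncio.Queue(), asyncio.Queue(), asyncio.Queue()
    async with asyncio.TaskGroup() as tg:  # requires Python 3.11+
        tg.create_task(stage("ocr", q1, q2))
        tg.create_task(stage("layout", q2, q3))
        for item in ("p1", "p2", None):
            await q1.put(item)
        assert await q3.get() == "layout(ocr(p1))"
        assert await q3.get() == "layout(ocr(p2))"
        assert await q3.get() is None


asyncio.run(main())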
@@ -489,8 +523,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):

     async def _batch_process_stage_safe(
         self,
-        input_queue: asyncio.Queue,
-        output_queue: asyncio.Queue,
+        input_queue: UpstreamAwareQueue,
+        output_queue: UpstreamAwareQueue,
         process_func,
         batch_size: int,
         timeout: float,
@@ -523,9 +557,19 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
                     batch.pages.append(item[0])
                     batch.conv_results.append(item[1])

-                # Try to fill batch up to batch_size
+                # Try to fill batch up to batch_size or until upstream exhausted
                 while len(batch.pages) < batch_size:
+                    # Check if upstream is exhausted and we should process immediately
+                    if input_queue.should_stop_waiting_for_batch():
+                        break
+
                     # Calculate remaining time for batch timeout
                     remaining_time = timeout - (time.time() - batch.start_time)
+
+                    # If upstream is complete, use minimal timeout to drain quickly
+                    if input_queue.upstream_complete.is_set():
+                        remaining_time = min(remaining_time, 0.05)
+
                     if remaining_time <= 0:
                         break

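Note: a sketch isolating the two new checks (the collect_batch helper is hypothetical, not this commit's code; UpstreamAwareQueue and QueueTerminator are assumed in scope). A batch stops growing the moment upstream is exhausted, and drains on a 50 ms timeout once the terminator has been seen, instead of always sleeping out the full batch timeout.

import asyncio
import time


async def collect_batch(
    q: "UpstreamAwareQueue", batch_size: int, timeout: float
) -> list:
    """Collect up to batch_size items, giving up early once upstream is done."""
    batch: list = []
    start = time.time()
    while len(batch) < batch_size:
        if q.should_stop_waiting_for_batch():  # upstream done and fully drained
            break
        remaining = timeout - (time.time() - start)
        if q.upstream_complete.is_set():  # terminator seen: drain quickly
            remaining = min(remaining, 0.05)
        if remaining <= 0:
            break
        try:
            item = await asyncio.wait_for(q.get(), remaining)
        except asyncio.TimeoutError:
            break
        if isinstance(item, QueueTerminator):
            break  # real code would propagate the sentinel downstream
        batch.append(item)
    return batch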
@@ -679,8 +723,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):

     async def _finalize_pages_safe(
         self,
-        input_queue: asyncio.Queue,
-        output_queue: asyncio.Queue,
+        input_queue: UpstreamAwareQueue,
+        output_queue: UpstreamAwareQueue,
         exception_event: asyncio.Event,
     ) -> None:
         """Finalize pages and track completion with proper termination"""
@@ -715,8 +759,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):

     async def _aggregate_results_safe(
         self,
-        completed_queue: asyncio.Queue,
-        completed_docs: asyncio.Queue,
+        completed_queue: UpstreamAwareQueue,
+        completed_docs: UpstreamAwareQueue,
         track_document_complete,
         exception_event: asyncio.Event,
     ) -> None:
pyproject.toml:

@@ -41,7 +41,7 @@ authors = [
     { name = "Panos Vagenas", email = "pva@zurich.ibm.com" },
     { name = "Peter Staar", email = "taa@zurich.ibm.com" },
 ]
-requires-python = '>=3.9,<4.0'
+requires-python = '>=3.11,<4.0'
 dependencies = [
     'pydantic (>=2.0.0,<3.0.0)',
     'docling-core[chunking] (>=2.42.0,<3.0.0)',
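Note: the requires-python bump from 3.9 to 3.11 is consistent with the pipeline code above, which relies on asyncio.TaskGroup; that API was added to the standard library in Python 3.11.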