mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-25 19:44:34 +00:00
UpstreamAwareQueue
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
f56de726f3
commit
ef25d03bc8
@ -54,6 +54,38 @@ class QueueTerminator:
|
|||||||
error: Optional[Exception] = None
|
error: Optional[Exception] = None
|
||||||
|
|
||||||
|
|
||||||
|
class UpstreamAwareQueue(asyncio.Queue):
|
||||||
|
"""Queue that tracks upstream completion to optimize batch processing"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.upstream_complete = asyncio.Event()
|
||||||
|
self.items_in_flight = 0
|
||||||
|
|
||||||
|
async def put(self, item):
|
||||||
|
"""Track items and upstream completion signals"""
|
||||||
|
if isinstance(item, QueueTerminator):
|
||||||
|
self.upstream_complete.set()
|
||||||
|
else:
|
||||||
|
self.items_in_flight += 1
|
||||||
|
await super().put(item)
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
"""Track when items are consumed"""
|
||||||
|
item = await super().get()
|
||||||
|
if not isinstance(item, QueueTerminator):
|
||||||
|
self.items_in_flight -= 1
|
||||||
|
return item
|
||||||
|
|
||||||
|
def should_stop_waiting_for_batch(self) -> bool:
|
||||||
|
"""Determine if we should stop waiting and process current batch"""
|
||||||
|
return (
|
||||||
|
self.upstream_complete.is_set()
|
||||||
|
and self.items_in_flight == 0
|
||||||
|
and self.empty()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncStandardPdfPipeline(AsyncPipeline):
|
class AsyncStandardPdfPipeline(AsyncPipeline):
|
||||||
"""Async pipeline implementation with cross-document batching using structured concurrency"""
|
"""Async pipeline implementation with cross-document batching using structured concurrency"""
|
||||||
|
|
||||||
@ -187,9 +219,11 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
) -> AsyncIterable[ConversionResult]:
|
) -> AsyncIterable[ConversionResult]:
|
||||||
"""Main async processing with structured concurrency and proper exception handling"""
|
"""Main async processing with structured concurrency and proper exception handling"""
|
||||||
# Create queues for pipeline stages
|
# Create queues for pipeline stages
|
||||||
page_queue = asyncio.Queue(maxsize=self.pipeline_options.extraction_queue_size)
|
page_queue = UpstreamAwareQueue(
|
||||||
completed_queue = asyncio.Queue()
|
maxsize=self.pipeline_options.extraction_queue_size
|
||||||
completed_docs = asyncio.Queue()
|
)
|
||||||
|
completed_queue = UpstreamAwareQueue()
|
||||||
|
completed_docs = UpstreamAwareQueue()
|
||||||
|
|
||||||
# Track active documents for proper termination
|
# Track active documents for proper termination
|
||||||
doc_tracker = {"active_docs": 0, "extraction_done": False}
|
doc_tracker = {"active_docs": 0, "extraction_done": False}
|
||||||
@ -254,7 +288,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
async def _extract_documents_wrapper(
|
async def _extract_documents_wrapper(
|
||||||
self,
|
self,
|
||||||
input_docs: AsyncIterable[InputDocument],
|
input_docs: AsyncIterable[InputDocument],
|
||||||
page_queue: asyncio.Queue,
|
page_queue: UpstreamAwareQueue,
|
||||||
track_document_start,
|
track_document_start,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
doc_tracker: Dict[str, Any],
|
doc_tracker: Dict[str, Any],
|
||||||
@ -279,8 +313,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
|
|
||||||
async def _process_pages_wrapper(
|
async def _process_pages_wrapper(
|
||||||
self,
|
self,
|
||||||
page_queue: asyncio.Queue,
|
page_queue: UpstreamAwareQueue,
|
||||||
completed_queue: asyncio.Queue,
|
completed_queue: UpstreamAwareQueue,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
):
|
):
|
||||||
"""Wrapper for page processing with exception handling"""
|
"""Wrapper for page processing with exception handling"""
|
||||||
@ -295,8 +329,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
|
|
||||||
async def _aggregate_results_wrapper(
|
async def _aggregate_results_wrapper(
|
||||||
self,
|
self,
|
||||||
completed_queue: asyncio.Queue,
|
completed_queue: UpstreamAwareQueue,
|
||||||
completed_docs: asyncio.Queue,
|
completed_docs: UpstreamAwareQueue,
|
||||||
track_document_complete,
|
track_document_complete,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
):
|
):
|
||||||
@ -313,7 +347,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _yield_results(
|
async def _yield_results(
|
||||||
self, completed_docs: asyncio.Queue, exception_event: asyncio.Event
|
self, completed_docs: UpstreamAwareQueue, exception_event: asyncio.Event
|
||||||
):
|
):
|
||||||
"""Yield results as they complete"""
|
"""Yield results as they complete"""
|
||||||
while True:
|
while True:
|
||||||
@ -334,7 +368,7 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
async def _extract_documents_safe(
|
async def _extract_documents_safe(
|
||||||
self,
|
self,
|
||||||
input_docs: AsyncIterable[InputDocument],
|
input_docs: AsyncIterable[InputDocument],
|
||||||
page_queue: asyncio.Queue,
|
page_queue: UpstreamAwareQueue,
|
||||||
track_document_start,
|
track_document_start,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -401,17 +435,17 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
|
|
||||||
async def _process_pages_safe(
|
async def _process_pages_safe(
|
||||||
self,
|
self,
|
||||||
page_queue: asyncio.Queue,
|
page_queue: UpstreamAwareQueue,
|
||||||
completed_queue: asyncio.Queue,
|
completed_queue: UpstreamAwareQueue,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process pages through model pipeline with proper termination"""
|
"""Process pages through model pipeline with proper termination"""
|
||||||
# Process batches through each model stage
|
# Process batches through each model stage
|
||||||
preprocessing_queue = asyncio.Queue()
|
preprocessing_queue = UpstreamAwareQueue()
|
||||||
ocr_queue = asyncio.Queue()
|
ocr_queue = UpstreamAwareQueue()
|
||||||
layout_queue = asyncio.Queue()
|
layout_queue = UpstreamAwareQueue()
|
||||||
table_queue = asyncio.Queue()
|
table_queue = UpstreamAwareQueue()
|
||||||
assemble_queue = asyncio.Queue()
|
assemble_queue = UpstreamAwareQueue()
|
||||||
|
|
||||||
# Start processing stages using TaskGroup
|
# Start processing stages using TaskGroup
|
||||||
async with asyncio.TaskGroup() as tg:
|
async with asyncio.TaskGroup() as tg:
|
||||||
@ -489,8 +523,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
|
|
||||||
async def _batch_process_stage_safe(
|
async def _batch_process_stage_safe(
|
||||||
self,
|
self,
|
||||||
input_queue: asyncio.Queue,
|
input_queue: UpstreamAwareQueue,
|
||||||
output_queue: asyncio.Queue,
|
output_queue: UpstreamAwareQueue,
|
||||||
process_func,
|
process_func,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
@ -523,9 +557,19 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
batch.pages.append(item[0])
|
batch.pages.append(item[0])
|
||||||
batch.conv_results.append(item[1])
|
batch.conv_results.append(item[1])
|
||||||
|
|
||||||
# Try to fill batch up to batch_size
|
# Try to fill batch up to batch_size or until upstream exhausted
|
||||||
while len(batch.pages) < batch_size:
|
while len(batch.pages) < batch_size:
|
||||||
|
# Check if upstream is exhausted and we should process immediately
|
||||||
|
if input_queue.should_stop_waiting_for_batch():
|
||||||
|
break
|
||||||
|
|
||||||
|
# Calculate remaining time for batch timeout
|
||||||
remaining_time = timeout - (time.time() - batch.start_time)
|
remaining_time = timeout - (time.time() - batch.start_time)
|
||||||
|
|
||||||
|
# If upstream is complete, use minimal timeout to drain quickly
|
||||||
|
if input_queue.upstream_complete.is_set():
|
||||||
|
remaining_time = min(remaining_time, 0.05)
|
||||||
|
|
||||||
if remaining_time <= 0:
|
if remaining_time <= 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -679,8 +723,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
|
|
||||||
async def _finalize_pages_safe(
|
async def _finalize_pages_safe(
|
||||||
self,
|
self,
|
||||||
input_queue: asyncio.Queue,
|
input_queue: UpstreamAwareQueue,
|
||||||
output_queue: asyncio.Queue,
|
output_queue: UpstreamAwareQueue,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Finalize pages and track completion with proper termination"""
|
"""Finalize pages and track completion with proper termination"""
|
||||||
@ -715,8 +759,8 @@ class AsyncStandardPdfPipeline(AsyncPipeline):
|
|||||||
|
|
||||||
async def _aggregate_results_safe(
|
async def _aggregate_results_safe(
|
||||||
self,
|
self,
|
||||||
completed_queue: asyncio.Queue,
|
completed_queue: UpstreamAwareQueue,
|
||||||
completed_docs: asyncio.Queue,
|
completed_docs: UpstreamAwareQueue,
|
||||||
track_document_complete,
|
track_document_complete,
|
||||||
exception_event: asyncio.Event,
|
exception_event: asyncio.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -41,7 +41,7 @@ authors = [
|
|||||||
{ name = "Panos Vagenas", email = "pva@zurich.ibm.com" },
|
{ name = "Panos Vagenas", email = "pva@zurich.ibm.com" },
|
||||||
{ name = "Peter Staar", email = "taa@zurich.ibm.com" },
|
{ name = "Peter Staar", email = "taa@zurich.ibm.com" },
|
||||||
]
|
]
|
||||||
requires-python = '>=3.9,<4.0'
|
requires-python = '>=3.11,<4.0'
|
||||||
dependencies = [
|
dependencies = [
|
||||||
'pydantic (>=2.0.0,<3.0.0)',
|
'pydantic (>=2.0.0,<3.0.0)',
|
||||||
'docling-core[chunking] (>=2.42.0,<3.0.0)',
|
'docling-core[chunking] (>=2.42.0,<3.0.0)',
|
||||||
|
Loading…
Reference in New Issue
Block a user