diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py index f7e82b5f..f11e08d8 100644 --- a/docling/models/api_vlm_model.py +++ b/docling/models/api_vlm_model.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from docling.datamodel.base_models import Page, VlmPrediction from docling.datamodel.document import ConversionResult @@ -36,12 +37,15 @@ class ApiVlmModel(BasePageModel): } def __call__( - self, conv_res: ConversionResult, page_batch: Iterable[Page] + self, + conv_res: ConversionResult, + page_batch: Iterable[Page], + concurrency: int = 1, ) -> Iterable[Page]: - for page in page_batch: + def _vlm_request(page): assert page._backend is not None if not page._backend.is_valid(): - yield page + return page else: with TimeRecorder(conv_res, "vlm"): assert page.size is not None @@ -63,4 +67,8 @@ class ApiVlmModel(BasePageModel): page.predictions.vlm_response = VlmPrediction(text=page_tags) - yield page + return page + + with ThreadPoolExecutor(max_workers=concurrency) as executor: + for result in executor.map(_vlm_request, page_batch): + yield from result diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py index 44bb5e21..2c19759d 100644 --- a/docling/models/picture_description_api_model.py +++ b/docling/models/picture_description_api_model.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Optional, Type, Union @@ -45,11 +46,13 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel): "pipeline_options.enable_remote_services=True." ) - def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]: + def _annotate_images( + self, images: Iterable[Image.Image], concurrency: int = 1 + ) -> Iterable[str]: # Note: technically we could make a batch request here, # but not all APIs will allow for it. For example, vllm won't allow more than 1. - for image in images: - yield api_image_request( + def _api_request(image): + return api_image_request( image=image, prompt=self.options.prompt, url=self.options.url, @@ -57,3 +60,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel): headers=self.options.headers, **self.options.params, ) + + with ThreadPoolExecutor(max_workers=concurrency) as executor: + for result in executor.map(_api_request, images): + yield from result