From 43239ff7121349594efdc5049512984f56db0e5f Mon Sep 17 00:00:00 2001 From: Peter Staar Date: Wed, 18 Jun 2025 06:50:10 +0200 Subject: [PATCH] finalised the first working ASR pipeline with Whisper Signed-off-by: Peter Staar --- docling/cli/main.py | 12 +++--- docling/datamodel/asr_model_specs.py | 16 ++++---- .../datamodel/pipeline_options_asr_model.py | 33 ++++++----------- docling/pipeline/asr_pipeline.py | 37 ++++++++++++++----- 4 files changed, 52 insertions(+), 46 deletions(-) diff --git a/docling/cli/main.py b/docling/cli/main.py index 24f276ae..34b6a14d 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -31,11 +31,11 @@ from docling.backend.pdf_backend import PdfDocumentBackend from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.asr_model_specs import ( - WHISPER_TINY, - WHISPER_SMALL, - WHISPER_MEDIUM, WHISPER_BASE, WHISPER_LARGE, + WHISPER_MEDIUM, + WHISPER_SMALL, + WHISPER_TINY, WHISPER_TURBO, AsrModelType, ) @@ -653,15 +653,15 @@ def convert( # noqa: C901 elif asr_model == AsrModelType.WHISPER_BASE: pipeline_options.asr_options = WHISPER_BASE elif asr_model == AsrModelType.WHISPER_LARGE: - pipeline_options.asr_options = WHISPER_LARGE + pipeline_options.asr_options = WHISPER_LARGE elif asr_model == AsrModelType.WHISPER_TURBO: - pipeline_options.asr_options = WHISPER_TURBO + pipeline_options.asr_options = WHISPER_TURBO else: _log.error(f"{asr_model} is not known") raise ValueError(f"{asr_model} is not known") _log.info(f"pipeline_options: {pipeline_options}") - + audio_format_option = AudioFormatOption( pipeline_cls=AsrPipeline, pipeline_options=pipeline_options, diff --git a/docling/datamodel/asr_model_specs.py b/docling/datamodel/asr_model_specs.py index e7cd0c9e..95287ad2 100644 --- a/docling/datamodel/asr_model_specs.py +++ b/docling/datamodel/asr_model_specs.py @@ -10,14 +10,13 @@ from docling.datamodel.pipeline_options_asr_model import ( # AsrResponseFormat, # ApiAsrOptions, InferenceAsrFramework, - InlineAsrOptions, + InlineAsrNativeWhisperOptions, TransformersModelType, ) _log = logging.getLogger(__name__) -# SmolDocling -WHISPER_TINY = InlineAsrOptions( +WHISPER_TINY = InlineAsrNativeWhisperOptions( repo_id="tiny", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -28,8 +27,7 @@ WHISPER_TINY = InlineAsrOptions( max_time_chunk=30.0, ) - -WHISPER_SMALL = InlineAsrOptions( +WHISPER_SMALL = InlineAsrNativeWhisperOptions( repo_id="small", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -40,7 +38,7 @@ WHISPER_SMALL = InlineAsrOptions( max_time_chunk=30.0, ) -WHISPER_MEDIUM = InlineAsrOptions( +WHISPER_MEDIUM = InlineAsrNativeWhisperOptions( repo_id="medium", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -51,7 +49,7 @@ WHISPER_MEDIUM = InlineAsrOptions( max_time_chunk=30.0, ) -WHISPER_BASE = InlineAsrOptions( +WHISPER_BASE = InlineAsrNativeWhisperOptions( repo_id="base", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -62,7 +60,7 @@ WHISPER_BASE = InlineAsrOptions( max_time_chunk=30.0, ) -WHISPER_LARGE = InlineAsrOptions( +WHISPER_LARGE = InlineAsrNativeWhisperOptions( repo_id="large", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -73,7 +71,7 @@ WHISPER_LARGE = InlineAsrOptions( max_time_chunk=30.0, ) -WHISPER_TURBO = InlineAsrOptions( +WHISPER_TURBO = InlineAsrNativeWhisperOptions( repo_id="turbo", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, diff --git a/docling/datamodel/pipeline_options_asr_model.py b/docling/datamodel/pipeline_options_asr_model.py index 11cbd5b0..e892254c 100644 --- a/docling/datamodel/pipeline_options_asr_model.py +++ b/docling/datamodel/pipeline_options_asr_model.py @@ -27,11 +27,8 @@ class InlineAsrOptions(BaseAsrOptions): repo_id: str - inference_framework: InferenceAsrFramework - verbose: bool = False timestamps: bool = True - word_timestamps: bool = True temperature: float = 0.0 max_new_tokens: int = 256 @@ -44,25 +41,17 @@ class InlineAsrOptions(BaseAsrOptions): AcceleratorDevice.MPS, ] - """ - repo_id: str - trust_remote_code: bool = False - load_in_8bit: bool = True - llm_int8_threshold: float = 6.0 - quantized: bool = False - - inference_framework: InferenceFramework - transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL - response_format: AsrResponseFormat - - temperature: float = 0.0 - stop_strings: List[str] = [] - extra_generation_config: Dict[str, Any] = {} - - use_kv_cache: bool = True - max_new_tokens: int = 4096 - """ - @property def repo_cache_folder(self) -> str: return self.repo_id.replace("/", "--") + + +class InlineAsrNativeWhisperOptions(InlineAsrOptions): + inference_framework: InferenceAsrFramework = InferenceAsrFramework.WHISPER + + language: str = "en" + supported_devices: List[AcceleratorDevice] = [ + AcceleratorDevice.CPU, + AcceleratorDevice.CUDA, + ] + word_timestamps: bool = True diff --git a/docling/pipeline/asr_pipeline.py b/docling/pipeline/asr_pipeline.py index 6e0a106d..3c213cce 100644 --- a/docling/pipeline/asr_pipeline.py +++ b/docling/pipeline/asr_pipeline.py @@ -28,6 +28,7 @@ from docling.datamodel.pipeline_options import ( AsrPipelineOptions, ) from docling.datamodel.pipeline_options_asr_model import ( + InlineAsrNativeWhisperOptions, # AsrResponseFormat, InlineAsrOptions, ) @@ -36,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import ( ) from docling.datamodel.settings import settings from docling.pipeline.base_pipeline import BasePipeline +from docling.utils.accelerator_utils import decide_device from docling.utils.profiling import ProfilingScope, TimeRecorder _log = logging.getLogger(__name__) @@ -96,17 +98,16 @@ class _NativeWhisperModel: enabled: bool, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions, - asr_options: InlineAsrOptions, - # model_name: str = "medium", + asr_options: InlineAsrNativeWhisperOptions, ): """ Transcriber using native Whisper. """ self.enabled = enabled - + _log.info(f"artifacts-path: {artifacts_path}") _log.info(f"accelerator_options: {accelerator_options}") - + if self.enabled: try: import whisper # type: ignore @@ -118,10 +119,25 @@ class _NativeWhisperModel: self.max_tokens = asr_options.max_new_tokens self.temperature = asr_options.temperature - + self.device = decide_device( + accelerator_options.device, + supported_devices=asr_options.supported_devices, + ) + _log.info(f"Available device for Whisper: {self.device}") + self.model_name = asr_options.repo_id _log.info(f"loading _NativeWhisperModel({self.model_name})") - self.model = whisper.load_model(self.model_name) + if artifacts_path is not None: + _log.info(f"loading {self.model_name} from {artifacts_path}") + self.model = whisper.load_model( + name=self.model_name, + device=self.device, + download_root=str(artifacts_path), + ) + else: + self.model = whisper.load_model( + name=self.model_name, device=self.device + ) self.verbose = asr_options.verbose self.timestamps = asr_options.timestamps @@ -189,15 +205,18 @@ class AsrPipeline(BasePipeline): "When defined, it must point to a folder containing all models required by the pipeline." ) - if isinstance(self.pipeline_options.asr_options, InlineAsrOptions): + if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions): + asr_options: InlineAsrNativeWhisperOptions = ( + self.pipeline_options.asr_options + ) self._model = _NativeWhisperModel( enabled=True, # must be always enabled for this pipeline to make sense. artifacts_path=artifacts_path, accelerator_options=pipeline_options.accelerator_options, - asr_options=pipeline_options.asr_options + asr_options=asr_options, ) else: - _log.error("") + _log.error(f"No model support for {self.pipeline_options.asr_options}") def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus: status = ConversionStatus.SUCCESS