From ed10d0993639185a3582044e7dfbab36cd6bf313 Mon Sep 17 00:00:00 2001 From: Peter Staar Date: Wed, 18 Jun 2025 06:23:33 +0200 Subject: [PATCH] updating with asr_options Signed-off-by: Peter Staar --- docling/cli/main.py | 21 ++++- docling/datamodel/asr_model_specs.py | 76 ++++++++++++++++-- docling/datamodel/base_models.py | 4 +- .../datamodel/pipeline_options_asr_model.py | 34 +++++--- docling/pipeline/asr_pipeline.py | 78 ++++++++++++++----- 5 files changed, 174 insertions(+), 39 deletions(-) diff --git a/docling/cli/main.py b/docling/cli/main.py index 54d5fc67..24f276ae 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -32,6 +32,11 @@ from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.asr_model_specs import ( WHISPER_TINY, + WHISPER_SMALL, + WHISPER_MEDIUM, + WHISPER_BASE, + WHISPER_LARGE, + WHISPER_TURBO, AsrModelType, ) from docling.datamodel.base_models import ( @@ -641,10 +646,22 @@ def convert( # noqa: C901 if asr_model == AsrModelType.WHISPER_TINY: pipeline_options.asr_options = WHISPER_TINY + elif asr_model == AsrModelType.WHISPER_SMALL: + pipeline_options.asr_options = WHISPER_SMALL + elif asr_model == AsrModelType.WHISPER_MEDIUM: + pipeline_options.asr_options = WHISPER_MEDIUM + elif asr_model == AsrModelType.WHISPER_BASE: + pipeline_options.asr_options = WHISPER_BASE + elif asr_model == AsrModelType.WHISPER_LARGE: + pipeline_options.asr_options = WHISPER_LARGE + elif asr_model == AsrModelType.WHISPER_TURBO: + pipeline_options.asr_options = WHISPER_TURBO else: - _log.warning("falling back in base ASR model: WHISPER_TINY") - pipeline_options.asr_options = WHISPER_TINY + _log.error(f"{asr_model} is not known") + raise ValueError(f"{asr_model} is not known") + _log.info(f"pipeline_options: {pipeline_options}") + audio_format_option = AudioFormatOption( pipeline_cls=AsrPipeline, pipeline_options=pipeline_options, 
diff --git a/docling/datamodel/asr_model_specs.py b/docling/datamodel/asr_model_specs.py index 6531f44f..e7cd0c9e 100644 --- a/docling/datamodel/asr_model_specs.py +++ b/docling/datamodel/asr_model_specs.py @@ -7,9 +7,9 @@ from pydantic import ( from docling.datamodel.accelerator_options import AcceleratorDevice from docling.datamodel.pipeline_options_asr_model import ( - AsrResponseFormat, + # AsrResponseFormat, # ApiAsrOptions, - InferenceFramework, + InferenceAsrFramework, InlineAsrOptions, TransformersModelType, ) @@ -18,11 +18,77 @@ _log = logging.getLogger(__name__) # SmolDocling WHISPER_TINY = InlineAsrOptions( - repo_id="openai/whisper-tiny", - inference_framework=InferenceFramework.TRANSFORMERS, - response_format=AsrResponseFormat.WHISPER, + repo_id="tiny", + inference_framework=InferenceAsrFramework.WHISPER, + verbose=True, + timestamps=True, + word_timestamps=True, + temperature=0.0, + max_new_tokens=256, + max_time_chunk=30.0, +) + + +WHISPER_SMALL = InlineAsrOptions( + repo_id="small", + inference_framework=InferenceAsrFramework.WHISPER, + verbose=True, + timestamps=True, + word_timestamps=True, + temperature=0.0, + max_new_tokens=256, + max_time_chunk=30.0, +) + +WHISPER_MEDIUM = InlineAsrOptions( + repo_id="medium", + inference_framework=InferenceAsrFramework.WHISPER, + verbose=True, + timestamps=True, + word_timestamps=True, + temperature=0.0, + max_new_tokens=256, + max_time_chunk=30.0, +) + +WHISPER_BASE = InlineAsrOptions( + repo_id="base", + inference_framework=InferenceAsrFramework.WHISPER, + verbose=True, + timestamps=True, + word_timestamps=True, + temperature=0.0, + max_new_tokens=256, + max_time_chunk=30.0, +) + +WHISPER_LARGE = InlineAsrOptions( + repo_id="large", + inference_framework=InferenceAsrFramework.WHISPER, + verbose=True, + timestamps=True, + word_timestamps=True, + temperature=0.0, + max_new_tokens=256, + max_time_chunk=30.0, +) + +WHISPER_TURBO = InlineAsrOptions( + repo_id="turbo", + 
inference_framework=InferenceAsrFramework.WHISPER, + verbose=True, + timestamps=True, + word_timestamps=True, + temperature=0.0, + max_new_tokens=256, + max_time_chunk=30.0, ) class AsrModelType(str, Enum): WHISPER_TINY = "whisper_tiny" + WHISPER_SMALL = "whisper_small" + WHISPER_MEDIUM = "whisper_medium" + WHISPER_BASE = "whisper_base" + WHISPER_LARGE = "whisper_large" + WHISPER_TURBO = "whisper_turbo" diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index 5426fb4d..a5095d22 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -49,7 +49,7 @@ class InputFormat(str, Enum): XML_USPTO = "xml_uspto" XML_JATS = "xml_jats" JSON_DOCLING = "json_docling" - AUDIO = "wav" + AUDIO = "audio" class OutputFormat(str, Enum): @@ -74,7 +74,7 @@ FormatToExtensions: Dict[InputFormat, List[str]] = { InputFormat.XLSX: ["xlsx", "xlsm"], InputFormat.XML_USPTO: ["xml", "txt"], InputFormat.JSON_DOCLING: ["json"], - InputFormat.AUDIO: ["wav"], + InputFormat.AUDIO: ["wav", "mp3"], } FormatToMimeType: Dict[InputFormat, List[str]] = { diff --git a/docling/datamodel/pipeline_options_asr_model.py b/docling/datamodel/pipeline_options_asr_model.py index 5ad161c0..11cbd5b0 100644 --- a/docling/datamodel/pipeline_options_asr_model.py +++ b/docling/datamodel/pipeline_options_asr_model.py @@ -6,7 +6,7 @@ from typing_extensions import deprecated from docling.datamodel.accelerator_options import AcceleratorDevice from docling.datamodel.pipeline_options_vlm_model import ( - InferenceFramework, + # InferenceFramework, TransformersModelType, ) @@ -16,13 +16,35 @@ class BaseAsrOptions(BaseModel): # prompt: str -class AsrResponseFormat(str, Enum): +class InferenceAsrFramework(str, Enum): + MLX = "mlx" + TRANSFORMERS = "transformers" WHISPER = "whisper" class InlineAsrOptions(BaseAsrOptions): kind: Literal["inline_model_options"] = "inline_model_options" + repo_id: str + + inference_framework: InferenceAsrFramework + + verbose: bool = False + 
timestamps: bool = True + word_timestamps: bool = True + + temperature: float = 0.0 + max_new_tokens: int = 256 + max_time_chunk: float = 30.0 + + torch_dtype: Optional[str] = None + supported_devices: List[AcceleratorDevice] = [ + AcceleratorDevice.CPU, + AcceleratorDevice.CUDA, + AcceleratorDevice.MPS, + ] + + """ repo_id: str trust_remote_code: bool = False load_in_8bit: bool = True @@ -33,19 +55,13 @@ class InlineAsrOptions(BaseAsrOptions): transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL response_format: AsrResponseFormat - torch_dtype: Optional[str] = None - supported_devices: List[AcceleratorDevice] = [ - AcceleratorDevice.CPU, - AcceleratorDevice.CUDA, - AcceleratorDevice.MPS, - ] - temperature: float = 0.0 stop_strings: List[str] = [] extra_generation_config: Dict[str, Any] = {} use_kv_cache: bool = True max_new_tokens: int = 4096 + """ @property def repo_cache_folder(self) -> str: diff --git a/docling/pipeline/asr_pipeline.py b/docling/pipeline/asr_pipeline.py index c8c39b16..6e0a106d 100644 --- a/docling/pipeline/asr_pipeline.py +++ b/docling/pipeline/asr_pipeline.py @@ -5,18 +5,21 @@ from io import BytesIO from pathlib import Path from typing import List, Optional, Union, cast -import whisper # type: ignore - +# import whisper # type: ignore # import librosa # import numpy as np # import soundfile as sf # type: ignore from docling_core.types.doc.labels import DocItemLabel from pydantic import BaseModel, Field, validator -# from pydub import AudioSegment # type: ignore -# from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline from docling.backend.abstract_backend import AbstractDocumentBackend from docling.backend.audio_backend import AudioBackend + +# from pydub import AudioSegment # type: ignore +# from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline +from docling.datamodel.accelerator_options import ( + AcceleratorOptions, +) from 
docling.datamodel.base_models import ( ConversionStatus, ) @@ -25,7 +28,7 @@ from docling.datamodel.pipeline_options import ( AsrPipelineOptions, ) from docling.datamodel.pipeline_options_asr_model import ( - AsrResponseFormat, + # AsrResponseFormat, InlineAsrOptions, ) from docling.datamodel.pipeline_options_vlm_model import ( @@ -88,15 +91,41 @@ class _ConversationItem(BaseModel): class _NativeWhisperModel: - def __init__(self, model_name: str = "medium"): + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + accelerator_options: AcceleratorOptions, + asr_options: InlineAsrOptions, + # model_name: str = "medium", + ): """ Transcriber using native Whisper. """ + self.enabled = enabled + + _log.info(f"artifacts-path: {artifacts_path}") + _log.info(f"accelerator_options: {accelerator_options}") + + if self.enabled: + try: + import whisper # type: ignore + except ImportError: + raise ImportError( + "whisper is not installed. Please install it via `pip install openai-whisper`." 
+ ) + self.asr_options = asr_options + self.max_tokens = asr_options.max_new_tokens + self.temperature = asr_options.temperature - self.model = whisper.load_model(model_name) + + self.model_name = asr_options.repo_id + _log.info(f"loading _NativeWhisperModel({self.model_name})") + self.model = whisper.load_model(self.model_name) - self.verbose = True - self.word_timestamps = True + self.verbose = asr_options.verbose + self.timestamps = asr_options.timestamps + self.word_timestamps = asr_options.word_timestamps def run(self, conv_res: ConversionResult) -> ConversionResult: audio_path: Path = Path(conv_res.input.file).resolve() @@ -126,15 +155,16 @@ class _NativeWhisperModel: item = _ConversationItem( start_time=_["start"], end_time=_["end"], text=_["text"], words=[] ) - item.words = [] - for __ in _["words"]: - item.words.append( - _ConversationWord( - start_time=__["start"], - end_time=__["end"], - text=__["word"], + if "words" in _ and self.word_timestamps: + item.words = [] + for __ in _["words"]: + item.words.append( + _ConversationWord( + start_time=__["start"], + end_time=__["end"], + text=__["word"], + ) ) - ) convo.append(item) return convo @@ -159,8 +189,15 @@ class AsrPipeline(BasePipeline): "When defined, it must point to a folder containing all models required by the pipeline." ) - # self._model = _WhisperModel() - self._model = _NativeWhisperModel() + if isinstance(self.pipeline_options.asr_options, InlineAsrOptions): + self._model = _NativeWhisperModel( + enabled=True, # must be always enabled for this pipeline to make sense. 
+ artifacts_path=artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + asr_options=pipeline_options.asr_options + ) + else: + _log.error("unsupported asr_options type; expected InlineAsrOptions") def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus: status = ConversionStatus.SUCCESS @@ -171,10 +208,9 @@ class AsrPipeline(BasePipeline): return AsrPipelineOptions() def _build_document(self, conv_res: ConversionResult) -> ConversionResult: + _log.info(f"start _build_document in AsrPipeline: {conv_res.input.file}") with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT): - _log.info(f"do something: {conv_res.input.file}") self._model.run(conv_res=conv_res) - _log.info(f"finished doing something: {conv_res.input.file}") return conv_res