finalised the first working ASR pipeline with Whisper

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-06-18 06:50:10 +02:00
parent ed10d09936
commit 43239ff712
4 changed files with 52 additions and 46 deletions

docling/cli/main.py

@@ -31,11 +31,11 @@ from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_BASE,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
     AsrModelType,
 )
@@ -653,15 +653,15 @@ def convert( # noqa: C901
         elif asr_model == AsrModelType.WHISPER_BASE:
             pipeline_options.asr_options = WHISPER_BASE
         elif asr_model == AsrModelType.WHISPER_LARGE:
             pipeline_options.asr_options = WHISPER_LARGE
         elif asr_model == AsrModelType.WHISPER_TURBO:
             pipeline_options.asr_options = WHISPER_TURBO
         else:
             _log.error(f"{asr_model} is not known")
             raise ValueError(f"{asr_model} is not known")

         _log.info(f"pipeline_options: {pipeline_options}")

         audio_format_option = AudioFormatOption(
             pipeline_cls=AsrPipeline,
             pipeline_options=pipeline_options,
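
For orientation, a minimal sketch of making the same selection programmatically; the DocumentConverter/InputFormat plumbing and the import locations marked "assumed" are taken from the surrounding docling codebase, not from this diff:

    from docling.datamodel.asr_model_specs import WHISPER_TURBO
    from docling.datamodel.base_models import InputFormat  # assumed location
    from docling.datamodel.pipeline_options import AsrPipelineOptions
    from docling.document_converter import AudioFormatOption, DocumentConverter  # assumed location
    from docling.pipeline.asr_pipeline import AsrPipeline

    # Pick a Whisper preset, mirroring the elif chain above.
    pipeline_options = AsrPipelineOptions()
    pipeline_options.asr_options = WHISPER_TURBO

    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )
    result = converter.convert("meeting.mp3")  # hypothetical input file
    print(result.document.export_to_markdown())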

docling/datamodel/asr_model_specs.py

@@ -10,14 +10,13 @@ from docling.datamodel.pipeline_options_asr_model import (
     # AsrResponseFormat,
     # ApiAsrOptions,
     InferenceAsrFramework,
-    InlineAsrOptions,
+    InlineAsrNativeWhisperOptions,
     TransformersModelType,
 )

 _log = logging.getLogger(__name__)

-# SmolDocling
-WHISPER_TINY = InlineAsrOptions(
+WHISPER_TINY = InlineAsrNativeWhisperOptions(
     repo_id="tiny",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -28,8 +27,7 @@ WHISPER_TINY = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-
-WHISPER_SMALL = InlineAsrOptions(
+WHISPER_SMALL = InlineAsrNativeWhisperOptions(
     repo_id="small",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -40,7 +38,7 @@ WHISPER_SMALL = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_MEDIUM = InlineAsrOptions(
+WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
     repo_id="medium",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -51,7 +49,7 @@ WHISPER_MEDIUM = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_BASE = InlineAsrOptions(
+WHISPER_BASE = InlineAsrNativeWhisperOptions(
     repo_id="base",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -62,7 +60,7 @@ WHISPER_BASE = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_LARGE = InlineAsrOptions(
+WHISPER_LARGE = InlineAsrNativeWhisperOptions(
     repo_id="large",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -73,7 +71,7 @@ WHISPER_LARGE = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_TURBO = InlineAsrOptions(
+WHISPER_TURBO = InlineAsrNativeWhisperOptions(
     repo_id="turbo",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
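
All six presets differ only in repo_id, so any checkpoint name that whisper.load_model accepts can be wrapped the same way. A sketch of a hypothetical additional preset, with field values mirroring the definitions above:

    WHISPER_BASE_EN = InlineAsrNativeWhisperOptions(
        repo_id="base.en",  # English-only checkpoint name understood by openai-whisper
        inference_framework=InferenceAsrFramework.WHISPER,
        verbose=True,
        timestamps=True,
        word_timestamps=True,
        temperature=0.0,
        max_new_tokens=256,
        max_time_chunk=30.0,
    )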

docling/datamodel/pipeline_options_asr_model.py

@@ -27,11 +27,8 @@ class InlineAsrOptions(BaseAsrOptions):
     repo_id: str

-    inference_framework: InferenceAsrFramework
-
     verbose: bool = False
     timestamps: bool = True
-    word_timestamps: bool = True

     temperature: float = 0.0
     max_new_tokens: int = 256

@@ -44,25 +41,17 @@ class InlineAsrOptions(BaseAsrOptions):
         AcceleratorDevice.MPS,
     ]

-    """
-    repo_id: str
-    trust_remote_code: bool = False
-    load_in_8bit: bool = True
-    llm_int8_threshold: float = 6.0
-    quantized: bool = False
-
-    inference_framework: InferenceFramework
-    transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
-    response_format: AsrResponseFormat
-
-    temperature: float = 0.0
-    stop_strings: List[str] = []
-    extra_generation_config: Dict[str, Any] = {}
-
-    use_kv_cache: bool = True
-    max_new_tokens: int = 4096
-    """
-
     @property
     def repo_cache_folder(self) -> str:
         return self.repo_id.replace("/", "--")
+
+
+class InlineAsrNativeWhisperOptions(InlineAsrOptions):
+    inference_framework: InferenceAsrFramework = InferenceAsrFramework.WHISPER
+
+    language: str = "en"
+    supported_devices: List[AcceleratorDevice] = [
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+    ]
+    word_timestamps: bool = True
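
Since the options classes use pydantic-style field annotations (an assumption from the syntax; the BaseAsrOptions definition is not in this diff), any default can be overridden per instance:

    # Hypothetical override: German transcription without word-level timestamps.
    german_whisper = InlineAsrNativeWhisperOptions(
        repo_id="medium",
        language="de",          # overrides the "en" default above
        word_timestamps=False,  # skip word-level alignment
    )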

docling/pipeline/asr_pipeline.py

@@ -28,6 +28,7 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
+    InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
 )
@@ -36,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
 )
 from docling.datamodel.settings import settings
 from docling.pipeline.base_pipeline import BasePipeline
+from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import ProfilingScope, TimeRecorder

 _log = logging.getLogger(__name__)
@@ -96,17 +98,16 @@ class _NativeWhisperModel:
         enabled: bool,
         artifacts_path: Optional[Path],
         accelerator_options: AcceleratorOptions,
-        asr_options: InlineAsrOptions,
-        # model_name: str = "medium",
+        asr_options: InlineAsrNativeWhisperOptions,
     ):
         """
         Transcriber using native Whisper.
         """
         self.enabled = enabled

         _log.info(f"artifacts-path: {artifacts_path}")
         _log.info(f"accelerator_options: {accelerator_options}")

         if self.enabled:
             try:
                 import whisper  # type: ignore
@@ -118,10 +119,25 @@ class _NativeWhisperModel:
             self.max_tokens = asr_options.max_new_tokens
             self.temperature = asr_options.temperature

+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=asr_options.supported_devices,
+            )
+            _log.info(f"Available device for Whisper: {self.device}")
+
             self.model_name = asr_options.repo_id
             _log.info(f"loading _NativeWhisperModel({self.model_name})")
-            self.model = whisper.load_model(self.model_name)
+            if artifacts_path is not None:
+                _log.info(f"loading {self.model_name} from {artifacts_path}")
+                self.model = whisper.load_model(
+                    name=self.model_name,
+                    device=self.device,
+                    download_root=str(artifacts_path),
+                )
+            else:
+                self.model = whisper.load_model(
+                    name=self.model_name, device=self.device
+                )

             self.verbose = asr_options.verbose
             self.timestamps = asr_options.timestamps
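
For reference, openai-whisper's load_model signature is load_model(name, device=None, download_root=None, in_memory=False), so passing artifacts_path as download_root simply redirects the checkpoint cache. The same calls in a standalone sketch (paths and file names are illustrative):

    import whisper

    model = whisper.load_model(
        name="turbo",  # the same strings the presets use as repo_id
        device="cpu",  # decide_device() would pick e.g. "cuda" when available
        download_root="/models/whisper",  # hypothetical cache directory
    )
    result = model.transcribe("audio.mp3", word_timestamps=True, temperature=0.0)
    print(result["text"])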
@@ -189,15 +205,18 @@ class AsrPipeline(BasePipeline):
                 "When defined, it must point to a folder containing all models required by the pipeline."
             )

-        if isinstance(self.pipeline_options.asr_options, InlineAsrOptions):
+        if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
+            asr_options: InlineAsrNativeWhisperOptions = (
+                self.pipeline_options.asr_options
+            )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=pipeline_options.asr_options
+                asr_options=asr_options,
             )
         else:
-            _log.error("")
+            _log.error(f"No model support for {self.pipeline_options.asr_options}")

     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
         status = ConversionStatus.SUCCESS
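
The asr_options local above is not redundant: assigning the attribute to an annotated name makes the isinstance() narrowing explicit for static type checkers, so Whisper-only fields such as supported_devices type-check inside the branch. The pattern in isolation (toy classes, illustrative only):

    class BaseOptions:
        pass

    class NativeWhisperOptions(BaseOptions):
        language: str = "en"

    def build(options: BaseOptions) -> None:
        if isinstance(options, NativeWhisperOptions):
            narrowed: NativeWhisperOptions = options  # explicit narrowing
            print(f"native whisper, language={narrowed.language}")
        else:
            print(f"No model support for {options}")

    build(NativeWhisperOptions())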