finalised the first working ASR pipeline with Whisper

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-06-18 06:50:10 +02:00
parent ed10d09936
commit 43239ff712
4 changed files with 52 additions and 46 deletions

View File

@@ -31,11 +31,11 @@ from docling.backend.pdf_backend import PdfDocumentBackend
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import ( from docling.datamodel.asr_model_specs import (
WHISPER_TINY,
WHISPER_SMALL,
WHISPER_MEDIUM,
WHISPER_BASE, WHISPER_BASE,
WHISPER_LARGE, WHISPER_LARGE,
WHISPER_MEDIUM,
WHISPER_SMALL,
WHISPER_TINY,
WHISPER_TURBO, WHISPER_TURBO,
AsrModelType, AsrModelType,
) )

View File

@@ -10,14 +10,13 @@ from docling.datamodel.pipeline_options_asr_model import (
# AsrResponseFormat, # AsrResponseFormat,
# ApiAsrOptions, # ApiAsrOptions,
InferenceAsrFramework, InferenceAsrFramework,
InlineAsrOptions, InlineAsrNativeWhisperOptions,
TransformersModelType, TransformersModelType,
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
# SmolDocling WHISPER_TINY = InlineAsrNativeWhisperOptions(
WHISPER_TINY = InlineAsrOptions(
repo_id="tiny", repo_id="tiny",
inference_framework=InferenceAsrFramework.WHISPER, inference_framework=InferenceAsrFramework.WHISPER,
verbose=True, verbose=True,
@@ -28,8 +27,7 @@ WHISPER_TINY = InlineAsrOptions(
max_time_chunk=30.0, max_time_chunk=30.0,
) )
WHISPER_SMALL = InlineAsrNativeWhisperOptions(
WHISPER_SMALL = InlineAsrOptions(
repo_id="small", repo_id="small",
inference_framework=InferenceAsrFramework.WHISPER, inference_framework=InferenceAsrFramework.WHISPER,
verbose=True, verbose=True,
@@ -40,7 +38,7 @@ WHISPER_SMALL = InlineAsrOptions(
max_time_chunk=30.0, max_time_chunk=30.0,
) )
WHISPER_MEDIUM = InlineAsrOptions( WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
repo_id="medium", repo_id="medium",
inference_framework=InferenceAsrFramework.WHISPER, inference_framework=InferenceAsrFramework.WHISPER,
verbose=True, verbose=True,
@@ -51,7 +49,7 @@ WHISPER_MEDIUM = InlineAsrOptions(
max_time_chunk=30.0, max_time_chunk=30.0,
) )
WHISPER_BASE = InlineAsrOptions( WHISPER_BASE = InlineAsrNativeWhisperOptions(
repo_id="base", repo_id="base",
inference_framework=InferenceAsrFramework.WHISPER, inference_framework=InferenceAsrFramework.WHISPER,
verbose=True, verbose=True,
@@ -62,7 +60,7 @@ WHISPER_BASE = InlineAsrOptions(
max_time_chunk=30.0, max_time_chunk=30.0,
) )
WHISPER_LARGE = InlineAsrOptions( WHISPER_LARGE = InlineAsrNativeWhisperOptions(
repo_id="large", repo_id="large",
inference_framework=InferenceAsrFramework.WHISPER, inference_framework=InferenceAsrFramework.WHISPER,
verbose=True, verbose=True,
@@ -73,7 +71,7 @@ WHISPER_LARGE = InlineAsrOptions(
max_time_chunk=30.0, max_time_chunk=30.0,
) )
WHISPER_TURBO = InlineAsrOptions( WHISPER_TURBO = InlineAsrNativeWhisperOptions(
repo_id="turbo", repo_id="turbo",
inference_framework=InferenceAsrFramework.WHISPER, inference_framework=InferenceAsrFramework.WHISPER,
verbose=True, verbose=True,

View File

@@ -27,11 +27,8 @@ class InlineAsrOptions(BaseAsrOptions):
repo_id: str repo_id: str
inference_framework: InferenceAsrFramework
verbose: bool = False verbose: bool = False
timestamps: bool = True timestamps: bool = True
word_timestamps: bool = True
temperature: float = 0.0 temperature: float = 0.0
max_new_tokens: int = 256 max_new_tokens: int = 256
@@ -44,25 +41,17 @@ class InlineAsrOptions(BaseAsrOptions):
AcceleratorDevice.MPS, AcceleratorDevice.MPS,
] ]
"""
repo_id: str
trust_remote_code: bool = False
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
response_format: AsrResponseFormat
temperature: float = 0.0
stop_strings: List[str] = []
extra_generation_config: Dict[str, Any] = {}
use_kv_cache: bool = True
max_new_tokens: int = 4096
"""
@property @property
def repo_cache_folder(self) -> str: def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--") return self.repo_id.replace("/", "--")
class InlineAsrNativeWhisperOptions(InlineAsrOptions):
inference_framework: InferenceAsrFramework = InferenceAsrFramework.WHISPER
language: str = "en"
supported_devices: List[AcceleratorDevice] = [
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
]
word_timestamps: bool = True

View File

@@ -28,6 +28,7 @@ from docling.datamodel.pipeline_options import (
AsrPipelineOptions, AsrPipelineOptions,
) )
from docling.datamodel.pipeline_options_asr_model import ( from docling.datamodel.pipeline_options_asr_model import (
InlineAsrNativeWhisperOptions,
# AsrResponseFormat, # AsrResponseFormat,
InlineAsrOptions, InlineAsrOptions,
) )
@@ -36,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
) )
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.pipeline.base_pipeline import BasePipeline from docling.pipeline.base_pipeline import BasePipeline
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import ProfilingScope, TimeRecorder from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@@ -96,8 +98,7 @@ class _NativeWhisperModel:
enabled: bool, enabled: bool,
artifacts_path: Optional[Path], artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions, accelerator_options: AcceleratorOptions,
asr_options: InlineAsrOptions, asr_options: InlineAsrNativeWhisperOptions,
# model_name: str = "medium",
): ):
""" """
Transcriber using native Whisper. Transcriber using native Whisper.
@@ -118,10 +119,25 @@ class _NativeWhisperModel:
self.max_tokens = asr_options.max_new_tokens self.max_tokens = asr_options.max_new_tokens
self.temperature = asr_options.temperature self.temperature = asr_options.temperature
self.device = decide_device(
accelerator_options.device,
supported_devices=asr_options.supported_devices,
)
_log.info(f"Available device for Whisper: {self.device}")
self.model_name = asr_options.repo_id self.model_name = asr_options.repo_id
_log.info(f"loading _NativeWhisperModel({self.model_name})") _log.info(f"loading _NativeWhisperModel({self.model_name})")
self.model = whisper.load_model(self.model_name) if artifacts_path is not None:
_log.info(f"loading {self.model_name} from {artifacts_path}")
self.model = whisper.load_model(
name=self.model_name,
device=self.device,
download_root=str(artifacts_path),
)
else:
self.model = whisper.load_model(
name=self.model_name, device=self.device
)
self.verbose = asr_options.verbose self.verbose = asr_options.verbose
self.timestamps = asr_options.timestamps self.timestamps = asr_options.timestamps
@@ -189,15 +205,18 @@ class AsrPipeline(BasePipeline):
"When defined, it must point to a folder containing all models required by the pipeline." "When defined, it must point to a folder containing all models required by the pipeline."
) )
if isinstance(self.pipeline_options.asr_options, InlineAsrOptions): if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
asr_options: InlineAsrNativeWhisperOptions = (
self.pipeline_options.asr_options
)
self._model = _NativeWhisperModel( self._model = _NativeWhisperModel(
enabled=True, # must be always enabled for this pipeline to make sense. enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path, artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options, accelerator_options=pipeline_options.accelerator_options,
asr_options=pipeline_options.asr_options asr_options=asr_options,
) )
else: else:
_log.error("") _log.error(f"No model support for {self.pipeline_options.asr_options}")
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus: def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
status = ConversionStatus.SUCCESS status = ConversionStatus.SUCCESS