finalised the first working ASR pipeline with Whisper

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-06-18 06:50:10 +02:00
parent ed10d09936
commit 43239ff712
4 changed files with 52 additions and 46 deletions

View File

@ -31,11 +31,11 @@ from docling.backend.pdf_backend import PdfDocumentBackend
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import (
WHISPER_TINY,
WHISPER_SMALL,
WHISPER_MEDIUM,
WHISPER_BASE,
WHISPER_LARGE,
WHISPER_MEDIUM,
WHISPER_SMALL,
WHISPER_TINY,
WHISPER_TURBO,
AsrModelType,
)

View File

@ -10,14 +10,13 @@ from docling.datamodel.pipeline_options_asr_model import (
# AsrResponseFormat,
# ApiAsrOptions,
InferenceAsrFramework,
InlineAsrOptions,
InlineAsrNativeWhisperOptions,
TransformersModelType,
)
_log = logging.getLogger(__name__)
# SmolDocling
WHISPER_TINY = InlineAsrOptions(
WHISPER_TINY = InlineAsrNativeWhisperOptions(
repo_id="tiny",
inference_framework=InferenceAsrFramework.WHISPER,
verbose=True,
@ -28,8 +27,7 @@ WHISPER_TINY = InlineAsrOptions(
max_time_chunk=30.0,
)
WHISPER_SMALL = InlineAsrOptions(
WHISPER_SMALL = InlineAsrNativeWhisperOptions(
repo_id="small",
inference_framework=InferenceAsrFramework.WHISPER,
verbose=True,
@ -40,7 +38,7 @@ WHISPER_SMALL = InlineAsrOptions(
max_time_chunk=30.0,
)
WHISPER_MEDIUM = InlineAsrOptions(
WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
repo_id="medium",
inference_framework=InferenceAsrFramework.WHISPER,
verbose=True,
@ -51,7 +49,7 @@ WHISPER_MEDIUM = InlineAsrOptions(
max_time_chunk=30.0,
)
WHISPER_BASE = InlineAsrOptions(
WHISPER_BASE = InlineAsrNativeWhisperOptions(
repo_id="base",
inference_framework=InferenceAsrFramework.WHISPER,
verbose=True,
@ -62,7 +60,7 @@ WHISPER_BASE = InlineAsrOptions(
max_time_chunk=30.0,
)
WHISPER_LARGE = InlineAsrOptions(
WHISPER_LARGE = InlineAsrNativeWhisperOptions(
repo_id="large",
inference_framework=InferenceAsrFramework.WHISPER,
verbose=True,
@ -73,7 +71,7 @@ WHISPER_LARGE = InlineAsrOptions(
max_time_chunk=30.0,
)
WHISPER_TURBO = InlineAsrOptions(
WHISPER_TURBO = InlineAsrNativeWhisperOptions(
repo_id="turbo",
inference_framework=InferenceAsrFramework.WHISPER,
verbose=True,

View File

@ -27,11 +27,8 @@ class InlineAsrOptions(BaseAsrOptions):
repo_id: str
inference_framework: InferenceAsrFramework
verbose: bool = False
timestamps: bool = True
word_timestamps: bool = True
temperature: float = 0.0
max_new_tokens: int = 256
@ -44,25 +41,17 @@ class InlineAsrOptions(BaseAsrOptions):
AcceleratorDevice.MPS,
]
"""
repo_id: str
trust_remote_code: bool = False
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
response_format: AsrResponseFormat
temperature: float = 0.0
stop_strings: List[str] = []
extra_generation_config: Dict[str, Any] = {}
use_kv_cache: bool = True
max_new_tokens: int = 4096
"""
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
class InlineAsrNativeWhisperOptions(InlineAsrOptions):
inference_framework: InferenceAsrFramework = InferenceAsrFramework.WHISPER
language: str = "en"
supported_devices: List[AcceleratorDevice] = [
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
]
word_timestamps: bool = True

View File

@ -28,6 +28,7 @@ from docling.datamodel.pipeline_options import (
AsrPipelineOptions,
)
from docling.datamodel.pipeline_options_asr_model import (
InlineAsrNativeWhisperOptions,
# AsrResponseFormat,
InlineAsrOptions,
)
@ -36,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
)
from docling.datamodel.settings import settings
from docling.pipeline.base_pipeline import BasePipeline
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
@ -96,8 +98,7 @@ class _NativeWhisperModel:
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
asr_options: InlineAsrOptions,
# model_name: str = "medium",
asr_options: InlineAsrNativeWhisperOptions,
):
"""
Transcriber using native Whisper.
@ -118,10 +119,25 @@ class _NativeWhisperModel:
self.max_tokens = asr_options.max_new_tokens
self.temperature = asr_options.temperature
self.device = decide_device(
accelerator_options.device,
supported_devices=asr_options.supported_devices,
)
_log.info(f"Available device for Whisper: {self.device}")
self.model_name = asr_options.repo_id
_log.info(f"loading _NativeWhisperModel({self.model_name})")
self.model = whisper.load_model(self.model_name)
if artifacts_path is not None:
_log.info(f"loading {self.model_name} from {artifacts_path}")
self.model = whisper.load_model(
name=self.model_name,
device=self.device,
download_root=str(artifacts_path),
)
else:
self.model = whisper.load_model(
name=self.model_name, device=self.device
)
self.verbose = asr_options.verbose
self.timestamps = asr_options.timestamps
@ -189,15 +205,18 @@ class AsrPipeline(BasePipeline):
"When defined, it must point to a folder containing all models required by the pipeline."
)
if isinstance(self.pipeline_options.asr_options, InlineAsrOptions):
if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
asr_options: InlineAsrNativeWhisperOptions = (
self.pipeline_options.asr_options
)
self._model = _NativeWhisperModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
asr_options=pipeline_options.asr_options
asr_options=asr_options,
)
else:
_log.error("")
_log.error(f"No model support for {self.pipeline_options.asr_options}")
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
status = ConversionStatus.SUCCESS