feat(ASR): MLX Whisper Support for Apple Silicon (#2366)

* add mlx-whisper support

* add mlx-whisper example and test; update the docling CLI to use MLX automatically when available

* fix pre-commit checks and add proper type safety

* fix linter issue

* DCO Remediation Commit for Ken Steele <ksteele@gmail.com>

I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: a979a680e1dc2fee8461401335cfb5dda8cfdd98
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 9827068382ca946fe1387ed83f747ae509fcf229
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: ebbeb45c7dc266260e1fad6bdb54a7041f8aeed4
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 2f6fd3cf46c8ca0bb98810191578278f1df87aa3

Signed-off-by: Ken Steele <ksteele@gmail.com>

* fix unit tests and code coverage for CI

* DCO Remediation Commit for Ken Steele <ksteele@gmail.com>

I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 5e61bf11139a2133978db2c8d306be6289aed732

Signed-off-by: Ken Steele <ksteele@gmail.com>

* fix CI example test: mlx_whisper_example.py defaults to tests/data/audio/sample_10s.mp3 when no arguments are given.

Signed-off-by: Ken Steele <ksteele@gmail.com>

* refactor: centralize audio file extensions and MIME types in base_models.py

- Move audio file extensions from the CLI's hardcoded set to FormatToExtensions[InputFormat.AUDIO]
- Add support for additional audio formats: m4a, aac, ogg, flac, mp4, avi, mov
- Update FormatToMimeType mapping to include MIME types for all audio formats
- Update CLI auto-detection to use centralized FormatToExtensions mapping
- Add comprehensive tests for audio file auto-detection and pipeline selection
- Ensure explicit pipeline choices are not overridden by auto-detection

Fixes an issue where only .mp3 and .wav files were processed as audio even
though CLI auto-detection worked for all formats. The document converter now
recognizes all audio formats through MIME type detection.

Addresses review comments:
- Centralizes audio extensions in base_models.py as suggested
- Maintains existing auto-detection behavior while using centralized data
- Adds proper test coverage for the audio detection functionality

All examples and tests pass with the new centralized approach.
All audio formats (mp3, wav, m4a, aac, ogg, flac, mp4, avi, mov) now work correctly.
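
As a rough illustration of what the centralized mapping enables (a sketch;
FormatToExtensions, InputFormat, and their import path are taken from this
change, while the helper itself is hypothetical):

from pathlib import Path

from docling.datamodel.base_models import FormatToExtensions, InputFormat

def is_audio_file(path: Path) -> bool:
    """Check a file's extension against the centralized audio extension list."""
    return path.suffix.lstrip(".").lower() in FormatToExtensions[InputFormat.AUDIO]

assert is_audio_file(Path("meeting.m4a"))
assert not is_audio_file(Path("report.pdf"))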

Signed-off-by: Ken Steele <ksteele@gmail.com>

* feat: address reviewer feedback - improve CLI auto-detection and add explicit model options

Review feedback addressed:
1. Fix CLI auto-detection to only switch to ASR pipeline when ALL files are audio
   - Previously switched if ANY file was audio, now requires ALL files to be audio
   - Added warning for mixed file types with guidance to use --pipeline asr

2. Add explicit WHISPER_X_MLX and WHISPER_X_NATIVE model options
   - Users can now force a specific implementation if desired (see the sketch below)
   - Auto-selecting models (WHISPER_BASE, etc.) still choose the best implementation for the hardware
   - Added 12 new explicit model options: _MLX and _NATIVE variants for each size

CLI now supports:
- Auto-selecting: whisper_tiny, whisper_base, etc. (choose best for hardware)
- Explicit MLX: whisper_tiny_mlx, whisper_base_mlx, etc. (force MLX)
- Explicit Native: whisper_tiny_native, whisper_base_native, etc. (force native)
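
For example, selecting between the three tiers programmatically (a sketch
using only names introduced in this PR):

from docling.datamodel.asr_model_specs import (
    WHISPER_BASE,         # auto-selects MLX on Apple Silicon, native elsewhere
    WHISPER_BASE_MLX,     # always MLX
    WHISPER_BASE_NATIVE,  # always native openai-whisper
)
from docling.datamodel.pipeline_options import AsrPipelineOptions

pipeline_options = AsrPipelineOptions()
pipeline_options.asr_options = WHISPER_BASE_MLX  # force MLX
# pipeline_options.asr_options = WHISPER_BASE_NATIVE  # force native whisper
# pipeline_options.asr_options = WHISPER_BASE  # let the hardware decide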

Addresses reviewer comments from @dolfim-ibm

Signed-off-by: Ken Steele <ksteele@gmail.com>

* DCO Remediation Commit for Ken Steele <ksteele@gmail.com>

I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: c60e72d2b5
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 94803317a3
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 21905e8ace
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 96c669d155
I, Ken Steele <ksteele@gmail.com>, hereby add my Signed-off-by to this commit: 8371c060ea

Signed-off-by: Ken Steele <ksteele@gmail.com>

* test(asr): add coverage for MLX options, pipeline helpers, and VLM prompts

- tests/test_asr_mlx_whisper.py: verify explicit MLX options (framework, repo ids)
- tests/test_asr_pipeline.py: cover _has_text/_determine_status and backend support with proper InputDocument/NoOpBackend wiring
- tests/test_interfaces.py: add BaseVlmPageModel.formulate_prompt tests (RAW/NONE/CHAT, invalid style), with minimal InlineVlmOptions scaffold

Improves reliability of ASR and VLM components by validating configuration paths and helper logic.
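
One of the explicit-options checks, roughly (the repo id and enum value are
taken from this PR; the exact test bodies live in tests/test_asr_mlx_whisper.py):

from docling.datamodel.asr_model_specs import WHISPER_TINY_MLX
from docling.datamodel.pipeline_options_asr_model import InferenceAsrFramework

def test_whisper_tiny_mlx_options():
    assert WHISPER_TINY_MLX.inference_framework == InferenceAsrFramework.MLX
    assert WHISPER_TINY_MLX.repo_id == "mlx-community/whisper-tiny-mlx"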

Signed-off-by: Ken Steele <ksteele@gmail.com>

* test(asr): broaden coverage for model selection, pipeline flows, and VLM prompts

- tests/test_asr_mlx_whisper.py
  - Add MLX/native selector coverage across all Whisper sizes
  - Validate repo_id choices under MLX and Native paths
  - Cover fallback path when MPS unavailable and mlx_whisper missing

- tests/test_asr_pipeline.py
  - Relax silent-audio assertion to accept PARTIAL_SUCCESS or SUCCESS
  - Force CPU native path in helper tests to avoid torch in device selection
  - Add language handling tests for native/MLX transcribe
  - Cover native run success (BytesIO) and failure (exception) branches
  - Cover MLX run success/failure branches with mocked transcribe
  - Add init path coverage with artifacts_path

- tests/test_interfaces.py
  - Add focused VLM prompt tests (NONE/CHAT variants)

Result: all tests pass, with significantly improved coverage for ASR model selectors, pipeline execution paths, and VLM prompt formulation.
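
The fallback case can be exercised by hiding mlx_whisper from the import
system; a sketch of the idea (the real tests in tests/test_asr_mlx_whisper.py
may be shaped differently):

import importlib
import sys
from unittest import mock

import docling.datamodel.asr_model_specs as specs
from docling.datamodel.pipeline_options_asr_model import InlineAsrNativeWhisperOptions

def test_auto_selector_falls_back_to_native():
    # A None entry in sys.modules makes `import mlx_whisper` raise ImportError,
    # so the auto-selecting helpers should return the native options.
    with mock.patch.dict(sys.modules, {"mlx_whisper": None}):
        reloaded = importlib.reload(specs)
        assert isinstance(reloaded.WHISPER_TINY, InlineAsrNativeWhisperOptions)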

Signed-off-by: Ken Steele <ksteele@gmail.com>

* simplify ASR model settings (no pipeline detection needed)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* clean up disk space in runners

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Ken Steele <ksteele@gmail.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>

docling/cli/main.py

@@ -32,11 +32,23 @@ from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
     WHISPER_BASE,
+    WHISPER_BASE_MLX,
+    WHISPER_BASE_NATIVE,
     WHISPER_LARGE,
+    WHISPER_LARGE_MLX,
+    WHISPER_LARGE_NATIVE,
     WHISPER_MEDIUM,
+    WHISPER_MEDIUM_MLX,
+    WHISPER_MEDIUM_NATIVE,
     WHISPER_SMALL,
+    WHISPER_SMALL_MLX,
+    WHISPER_SMALL_NATIVE,
     WHISPER_TINY,
+    WHISPER_TINY_MLX,
+    WHISPER_TINY_NATIVE,
     WHISPER_TURBO,
+    WHISPER_TURBO_MLX,
+    WHISPER_TURBO_NATIVE,
     AsrModelType,
 )
 from docling.datamodel.base_models import (
@@ -611,6 +623,7 @@ def convert( # noqa: C901
         ocr_options.psm = psm
 
     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
 
     # pipeline_options: PaginatedPipelineOptions
+    pipeline_options: PipelineOptions
@@ -747,42 +760,74 @@ def convert( # noqa: C901
             InputFormat.IMAGE: pdf_format_option,
         }
     elif pipeline == ProcessingPipeline.ASR:
-        pipeline_options = AsrPipelineOptions(
-            # enable_remote_services=enable_remote_services,
-            # artifacts_path = artifacts_path
-        )
         # Set ASR options
+        asr_pipeline_options = AsrPipelineOptions(
+            accelerator_options=AcceleratorOptions(
+                device=device,
+                num_threads=num_threads,
+            ),
+            # enable_remote_services=enable_remote_services,
+            # artifacts_path = artifacts_path
+        )
-        if asr_model == AsrModelType.WHISPER_TINY:
-            pipeline_options.asr_options = WHISPER_TINY
-        elif asr_model == AsrModelType.WHISPER_SMALL:
-            pipeline_options.asr_options = WHISPER_SMALL
-        elif asr_model == AsrModelType.WHISPER_MEDIUM:
-            pipeline_options.asr_options = WHISPER_MEDIUM
-        elif asr_model == AsrModelType.WHISPER_BASE:
-            pipeline_options.asr_options = WHISPER_BASE
-        elif asr_model == AsrModelType.WHISPER_LARGE:
-            pipeline_options.asr_options = WHISPER_LARGE
-        elif asr_model == AsrModelType.WHISPER_TURBO:
-            pipeline_options.asr_options = WHISPER_TURBO
-        else:
-            _log.error(f"{asr_model} is not known")
-            raise ValueError(f"{asr_model} is not known")
+        # Auto-selecting models (choose best implementation for hardware)
+        if asr_model == AsrModelType.WHISPER_TINY:
+            asr_pipeline_options.asr_options = WHISPER_TINY
+        elif asr_model == AsrModelType.WHISPER_SMALL:
+            asr_pipeline_options.asr_options = WHISPER_SMALL
+        elif asr_model == AsrModelType.WHISPER_MEDIUM:
+            asr_pipeline_options.asr_options = WHISPER_MEDIUM
+        elif asr_model == AsrModelType.WHISPER_BASE:
+            asr_pipeline_options.asr_options = WHISPER_BASE
+        elif asr_model == AsrModelType.WHISPER_LARGE:
+            asr_pipeline_options.asr_options = WHISPER_LARGE
+        elif asr_model == AsrModelType.WHISPER_TURBO:
+            asr_pipeline_options.asr_options = WHISPER_TURBO
-        _log.info(f"pipeline_options: {pipeline_options}")
+        # Explicit MLX models (force MLX implementation)
+        elif asr_model == AsrModelType.WHISPER_TINY_MLX:
+            asr_pipeline_options.asr_options = WHISPER_TINY_MLX
+        elif asr_model == AsrModelType.WHISPER_SMALL_MLX:
+            asr_pipeline_options.asr_options = WHISPER_SMALL_MLX
+        elif asr_model == AsrModelType.WHISPER_MEDIUM_MLX:
+            asr_pipeline_options.asr_options = WHISPER_MEDIUM_MLX
+        elif asr_model == AsrModelType.WHISPER_BASE_MLX:
+            asr_pipeline_options.asr_options = WHISPER_BASE_MLX
+        elif asr_model == AsrModelType.WHISPER_LARGE_MLX:
+            asr_pipeline_options.asr_options = WHISPER_LARGE_MLX
+        elif asr_model == AsrModelType.WHISPER_TURBO_MLX:
+            asr_pipeline_options.asr_options = WHISPER_TURBO_MLX
-        audio_format_option = AudioFormatOption(
-            pipeline_cls=AsrPipeline,
-            pipeline_options=pipeline_options,
-        )
+        # Explicit Native models (force native implementation)
+        elif asr_model == AsrModelType.WHISPER_TINY_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_TINY_NATIVE
+        elif asr_model == AsrModelType.WHISPER_SMALL_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_SMALL_NATIVE
+        elif asr_model == AsrModelType.WHISPER_MEDIUM_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_MEDIUM_NATIVE
+        elif asr_model == AsrModelType.WHISPER_BASE_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_BASE_NATIVE
+        elif asr_model == AsrModelType.WHISPER_LARGE_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_LARGE_NATIVE
+        elif asr_model == AsrModelType.WHISPER_TURBO_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_TURBO_NATIVE
-        format_options = {
-            InputFormat.AUDIO: audio_format_option,
-        }
+        else:
+            _log.error(f"{asr_model} is not known")
+            raise ValueError(f"{asr_model} is not known")
+        _log.info(f"ASR pipeline_options: {asr_pipeline_options}")
+        audio_format_option = AudioFormatOption(
+            pipeline_cls=AsrPipeline,
+            pipeline_options=asr_pipeline_options,
+        )
+        format_options[InputFormat.AUDIO] = audio_format_option
     # Common options for all pipelines
     if artifacts_path is not None:
         pipeline_options.artifacts_path = artifacts_path
-        # audio_pipeline_options.artifacts_path = artifacts_path
+        asr_pipeline_options.artifacts_path = artifacts_path
     doc_converter = DocumentConverter(
         allowed_formats=from_formats,
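
Outside the CLI, the same wiring looks roughly like this (a sketch;
AudioFormatOption, AsrPipeline, and the option names appear in this diff,
while the import paths are assumed from docling's package layout):

from docling.datamodel.asr_model_specs import WHISPER_TURBO
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline

converter = DocumentConverter(
    format_options={
        InputFormat.AUDIO: AudioFormatOption(
            pipeline_cls=AsrPipeline,
            pipeline_options=AsrPipelineOptions(asr_options=WHISPER_TURBO),
        )
    }
)
result = converter.convert("tests/data/audio/sample_10s.mp3")
print(result.document.export_to_markdown())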

docling/datamodel/asr_model_specs.py

@@ -10,13 +10,394 @@ from docling.datamodel.pipeline_options_asr_model import (
     # AsrResponseFormat,
     # ApiAsrOptions,
     InferenceAsrFramework,
+    InlineAsrMlxWhisperOptions,
     InlineAsrNativeWhisperOptions,
     TransformersModelType,
 )
 
 _log = logging.getLogger(__name__)
 
-WHISPER_TINY = InlineAsrNativeWhisperOptions(
+def _get_whisper_tiny_model():
+    """
+    Get the best Whisper Tiny model for the current hardware.
+
+    Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Tiny.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-tiny-mlx",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="tiny",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_TINY = _get_whisper_tiny_model()
+
+
+def _get_whisper_small_model():
+    """
+    Get the best Whisper Small model for the current hardware.
+
+    Automatically selects MLX Whisper Small for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Small.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-small-mlx",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="small",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_SMALL = _get_whisper_small_model()
+
+
+def _get_whisper_medium_model():
+    """
+    Get the best Whisper Medium model for the current hardware.
+
+    Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Medium.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-medium-mlx-8bit",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="medium",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_MEDIUM = _get_whisper_medium_model()
+
+
+def _get_whisper_base_model():
+    """
+    Get the best Whisper Base model for the current hardware.
+
+    Automatically selects MLX Whisper Base for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Base.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-base-mlx",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="base",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_BASE = _get_whisper_base_model()
+
+
+def _get_whisper_large_model():
+    """
+    Get the best Whisper Large model for the current hardware.
+
+    Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Large.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-large-mlx-8bit",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="large",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_LARGE = _get_whisper_large_model()
+
+
+def _get_whisper_turbo_model():
+    """
+    Get the best Whisper Turbo model for the current hardware.
+
+    Automatically selects MLX Whisper Turbo for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Turbo.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-turbo",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="turbo",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_TURBO = _get_whisper_turbo_model()
+
+
+# Explicit MLX Whisper model options for users who want to force MLX usage
+WHISPER_TINY_MLX = InlineAsrMlxWhisperOptions(
+    repo_id="mlx-community/whisper-tiny-mlx",
+    inference_framework=InferenceAsrFramework.MLX,
+    language="en",
+    task="transcribe",
+    word_timestamps=True,
+    no_speech_threshold=0.6,
+    logprob_threshold=-1.0,
+    compression_ratio_threshold=2.4,
+)
+
+WHISPER_SMALL_MLX = InlineAsrMlxWhisperOptions(
+    repo_id="mlx-community/whisper-small-mlx",
+    inference_framework=InferenceAsrFramework.MLX,
+    language="en",
+    task="transcribe",
+    word_timestamps=True,
+    no_speech_threshold=0.6,
+    logprob_threshold=-1.0,
+    compression_ratio_threshold=2.4,
+)
+
+WHISPER_MEDIUM_MLX = InlineAsrMlxWhisperOptions(
+    repo_id="mlx-community/whisper-medium-mlx-8bit",
+    inference_framework=InferenceAsrFramework.MLX,
+    language="en",
+    task="transcribe",
+    word_timestamps=True,
+    no_speech_threshold=0.6,
+    logprob_threshold=-1.0,
+    compression_ratio_threshold=2.4,
+)
+
+WHISPER_BASE_MLX = InlineAsrMlxWhisperOptions(
+    repo_id="mlx-community/whisper-base-mlx",
+    inference_framework=InferenceAsrFramework.MLX,
+    language="en",
+    task="transcribe",
+    word_timestamps=True,
+    no_speech_threshold=0.6,
+    logprob_threshold=-1.0,
+    compression_ratio_threshold=2.4,
+)
+
+WHISPER_LARGE_MLX = InlineAsrMlxWhisperOptions(
+    repo_id="mlx-community/whisper-large-mlx-8bit",
+    inference_framework=InferenceAsrFramework.MLX,
+    language="en",
+    task="transcribe",
+    word_timestamps=True,
+    no_speech_threshold=0.6,
+    logprob_threshold=-1.0,
+    compression_ratio_threshold=2.4,
+)
+
+WHISPER_TURBO_MLX = InlineAsrMlxWhisperOptions(
+    repo_id="mlx-community/whisper-turbo",
+    inference_framework=InferenceAsrFramework.MLX,
+    language="en",
+    task="transcribe",
+    word_timestamps=True,
+    no_speech_threshold=0.6,
+    logprob_threshold=-1.0,
+    compression_ratio_threshold=2.4,
+)
+
+# Explicit Native Whisper model options for users who want to force native usage
+WHISPER_TINY_NATIVE = InlineAsrNativeWhisperOptions(
     repo_id="tiny",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -27,7 +408,7 @@ WHISPER_TINY = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_SMALL = InlineAsrNativeWhisperOptions(
+WHISPER_SMALL_NATIVE = InlineAsrNativeWhisperOptions(
     repo_id="small",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -38,7 +419,7 @@ WHISPER_SMALL = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
+WHISPER_MEDIUM_NATIVE = InlineAsrNativeWhisperOptions(
     repo_id="medium",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -49,7 +430,7 @@ WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_BASE = InlineAsrNativeWhisperOptions(
+WHISPER_BASE_NATIVE = InlineAsrNativeWhisperOptions(
     repo_id="base",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -60,7 +441,7 @@ WHISPER_BASE = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_LARGE = InlineAsrNativeWhisperOptions(
+WHISPER_LARGE_NATIVE = InlineAsrNativeWhisperOptions(
     repo_id="large",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -71,7 +452,7 @@ WHISPER_LARGE = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_TURBO = InlineAsrNativeWhisperOptions(
+WHISPER_TURBO_NATIVE = InlineAsrNativeWhisperOptions(
     repo_id="turbo",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -82,11 +463,32 @@ WHISPER_TURBO = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
+# Note: The main WHISPER_* models (WHISPER_TURBO, WHISPER_BASE, etc.) automatically
+# select the best implementation (MLX on Apple Silicon, Native elsewhere).
+# Use the explicit _MLX or _NATIVE variants if you need to force a specific implementation.
+
 class AsrModelType(str, Enum):
+    # Auto-selecting models (choose best implementation for hardware)
     WHISPER_TINY = "whisper_tiny"
     WHISPER_SMALL = "whisper_small"
     WHISPER_MEDIUM = "whisper_medium"
     WHISPER_BASE = "whisper_base"
     WHISPER_LARGE = "whisper_large"
     WHISPER_TURBO = "whisper_turbo"
+
+    # Explicit MLX models (force MLX implementation)
+    WHISPER_TINY_MLX = "whisper_tiny_mlx"
+    WHISPER_SMALL_MLX = "whisper_small_mlx"
+    WHISPER_MEDIUM_MLX = "whisper_medium_mlx"
+    WHISPER_BASE_MLX = "whisper_base_mlx"
+    WHISPER_LARGE_MLX = "whisper_large_mlx"
+    WHISPER_TURBO_MLX = "whisper_turbo_mlx"
+
+    # Explicit Native models (force native implementation)
+    WHISPER_TINY_NATIVE = "whisper_tiny_native"
+    WHISPER_SMALL_NATIVE = "whisper_small_native"
+    WHISPER_MEDIUM_NATIVE = "whisper_medium_native"
+    WHISPER_BASE_NATIVE = "whisper_base_native"
+    WHISPER_LARGE_NATIVE = "whisper_large_native"
+    WHISPER_TURBO_NATIVE = "whisper_turbo_native"

docling/datamodel/base_models.py

@@ -94,7 +94,7 @@ FormatToExtensions: dict[InputFormat, list[str]] = {
     InputFormat.XML_USPTO: ["xml", "txt"],
     InputFormat.METS_GBS: ["tar.gz"],
     InputFormat.JSON_DOCLING: ["json"],
-    InputFormat.AUDIO: ["wav", "mp3"],
+    InputFormat.AUDIO: ["wav", "mp3", "m4a", "aac", "ogg", "flac", "mp4", "avi", "mov"],
     InputFormat.VTT: ["vtt"],
 }
@@ -128,7 +128,22 @@ FormatToMimeType: dict[InputFormat, list[str]] = {
     InputFormat.XML_USPTO: ["application/xml", "text/plain"],
     InputFormat.METS_GBS: ["application/mets+xml"],
     InputFormat.JSON_DOCLING: ["application/json"],
-    InputFormat.AUDIO: ["audio/x-wav", "audio/mpeg", "audio/wav", "audio/mp3"],
+    InputFormat.AUDIO: [
+        "audio/x-wav",
+        "audio/mpeg",
+        "audio/wav",
+        "audio/mp3",
+        "audio/mp4",
+        "audio/m4a",
+        "audio/aac",
+        "audio/ogg",
+        "audio/flac",
+        "audio/x-flac",
+        "video/mp4",
+        "video/avi",
+        "video/x-msvideo",
+        "video/quicktime",
+    ],
     InputFormat.VTT: ["text/vtt"],
 }
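
With the extended mapping, MIME-based detection now covers the container
formats as well; for instance (a sketch that uses the standard-library
mimetypes module for illustration rather than docling's actual detection code):

import mimetypes

from docling.datamodel.base_models import FormatToMimeType, InputFormat

mime, _ = mimetypes.guess_type("interview.mov")  # "video/quicktime"
assert mime in FormatToMimeType[InputFormat.AUDIO]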

docling/datamodel/pipeline_options_asr_model.py

@@ -17,7 +17,7 @@ class BaseAsrOptions(BaseModel):
 
 class InferenceAsrFramework(str, Enum):
-    # MLX = "mlx" # disabled for now
+    MLX = "mlx"
     # TRANSFORMERS = "transformers" # disabled for now
     WHISPER = "whisper"
@@ -55,3 +55,23 @@ class InlineAsrNativeWhisperOptions(InlineAsrOptions):
         AcceleratorDevice.CUDA,
     ]
     word_timestamps: bool = True
+
+
+class InlineAsrMlxWhisperOptions(InlineAsrOptions):
+    """
+    MLX Whisper options for Apple Silicon optimization.
+
+    Uses mlx-whisper library for efficient inference on Apple Silicon devices.
+    """
+
+    inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
+    language: str = "en"
+    task: str = "transcribe"  # "transcribe" or "translate"
+
+    supported_devices: List[AcceleratorDevice] = [
+        AcceleratorDevice.MPS,  # MLX is optimized for Apple Silicon
+    ]
+
+    word_timestamps: bool = True
+    no_speech_threshold: float = 0.6  # Threshold for detecting speech
+    logprob_threshold: float = -1.0  # Log probability threshold
+    compression_ratio_threshold: float = 2.4  # Compression ratio threshold
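
Since the new class is a regular options model, the defaults above can be
overridden per use case, e.g. a non-English profile (field names are from the
class above; the values chosen here are illustrative):

from docling.datamodel.pipeline_options_asr_model import InlineAsrMlxWhisperOptions

german_opts = InlineAsrMlxWhisperOptions(
    repo_id="mlx-community/whisper-large-mlx-8bit",
    language="de",
    task="transcribe",
    no_speech_threshold=0.5,  # slightly more permissive speech gate
)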

docling/pipeline/asr_pipeline.py

@@ -4,7 +4,7 @@ import re
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union, cast
 
 from docling_core.types.doc import DoclingDocument, DocumentOrigin
@@ -32,6 +32,7 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
+    InlineAsrMlxWhisperOptions,
     InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
@@ -228,22 +229,157 @@ class _NativeWhisperModel:
         return convo
 
 
+class _MlxWhisperModel:
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        asr_options: InlineAsrMlxWhisperOptions,
+    ):
+        """
+        Transcriber using MLX Whisper for Apple Silicon optimization.
+        """
+        self.enabled = enabled
+
+        _log.info(f"artifacts-path: {artifacts_path}")
+        _log.info(f"accelerator_options: {accelerator_options}")
+
+        if self.enabled:
+            try:
+                import mlx_whisper  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "mlx-whisper is not installed. Please install it via `pip install mlx-whisper` or do `uv sync --extra asr`."
+                )
+
+            self.asr_options = asr_options
+            self.mlx_whisper = mlx_whisper
+
+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=asr_options.supported_devices,
+            )
+            _log.info(f"Available device for MLX Whisper: {self.device}")
+
+            self.model_name = asr_options.repo_id
+            _log.info(f"loading _MlxWhisperModel({self.model_name})")
+
+            # MLX Whisper models are loaded differently - they use HuggingFace repos
+            self.model_path = self.model_name
+
+            # Store MLX-specific options
+            self.language = asr_options.language
+            self.task = asr_options.task
+            self.word_timestamps = asr_options.word_timestamps
+            self.no_speech_threshold = asr_options.no_speech_threshold
+            self.logprob_threshold = asr_options.logprob_threshold
+            self.compression_ratio_threshold = asr_options.compression_ratio_threshold
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        audio_path: Path = Path(conv_res.input.file).resolve()
+
+        try:
+            conversation = self.transcribe(audio_path)
+
+            # Ensure we have a proper DoclingDocument
+            origin = DocumentOrigin(
+                filename=conv_res.input.file.name or "audio.wav",
+                mimetype="audio/x-wav",
+                binary_hash=conv_res.input.document_hash,
+            )
+            conv_res.document = DoclingDocument(
+                name=conv_res.input.file.stem or "audio.wav", origin=origin
+            )
+
+            for citem in conversation:
+                conv_res.document.add_text(
+                    label=DocItemLabel.TEXT, text=citem.to_string()
+                )
+
+            conv_res.status = ConversionStatus.SUCCESS
+            return conv_res
+
+        except Exception as exc:
+            _log.error(f"MLX Audio transcription has an error: {exc}")
+            conv_res.status = ConversionStatus.FAILURE
+            return conv_res
+
+    def transcribe(self, fpath: Path) -> list[_ConversationItem]:
+        """
+        Transcribe audio using MLX Whisper.
+
+        Args:
+            fpath: Path to audio file
+
+        Returns:
+            List of conversation items with timestamps
+        """
+        result = self.mlx_whisper.transcribe(
+            str(fpath),
+            path_or_hf_repo=self.model_path,
+            language=self.language,
+            task=self.task,
+            word_timestamps=self.word_timestamps,
+            no_speech_threshold=self.no_speech_threshold,
+            logprob_threshold=self.logprob_threshold,
+            compression_ratio_threshold=self.compression_ratio_threshold,
+        )
+
+        convo: list[_ConversationItem] = []
+
+        # MLX Whisper returns segments similar to native Whisper
+        for segment in result.get("segments", []):
+            item = _ConversationItem(
+                start_time=segment.get("start"),
+                end_time=segment.get("end"),
+                text=segment.get("text", "").strip(),
+                words=[],
+            )
+
+            # Add word-level timestamps if available
+            if self.word_timestamps and "words" in segment:
+                item.words = []
+                for word_data in segment["words"]:
+                    item.words.append(
+                        _ConversationWord(
+                            start_time=word_data.get("start"),
+                            end_time=word_data.get("end"),
+                            text=word_data.get("word", ""),
+                        )
+                    )
+            convo.append(item)
+
+        return convo
+
+
 class AsrPipeline(BasePipeline):
     def __init__(self, pipeline_options: AsrPipelineOptions):
         super().__init__(pipeline_options)
         self.keep_backend = True
 
         self.pipeline_options: AsrPipelineOptions = pipeline_options
+        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]
 
         if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
-            asr_options: InlineAsrNativeWhisperOptions = (
+            native_asr_options: InlineAsrNativeWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=native_asr_options,
             )
+        elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
+            mlx_asr_options: InlineAsrMlxWhisperOptions = (
+                self.pipeline_options.asr_options
+            )
+            self._model = _MlxWhisperModel(
+                enabled=True,  # must be always enabled for this pipeline to make sense.
+                artifacts_path=self.artifacts_path,
+                accelerator_options=pipeline_options.accelerator_options,
+                asr_options=mlx_asr_options,
+            )
         else:
             _log.error(f"No model support for {self.pipeline_options.asr_options}")