Added mlx-whisper example and test; updated the docling CLI to use MLX automatically if present.

Ken Steele
2025-10-02 04:29:30 -07:00
parent c60e72d2b5
commit 94803317a3
4 changed files with 494 additions and 30 deletions
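The selection is visible at import time: the module-level Whisper constants defined below resolve to MLX options on Apple Silicon when mlx-whisper is importable, and fall back to native Whisper options otherwise. A minimal sanity check (a sketch, assuming docling is installed with its ASR extra; it uses only names that appear in this diff):

    from docling.datamodel.asr_model_specs import WHISPER_TINY, WHISPER_MEDIUM, WHISPER_LARGE
    from docling.datamodel.pipeline_options_asr_model import InferenceAsrFramework

    # On Apple Silicon with mlx-whisper installed these resolve to MLX options;
    # elsewhere they fall back to the native Whisper options.
    for name, spec in [("tiny", WHISPER_TINY), ("medium", WHISPER_MEDIUM), ("large", WHISPER_LARGE)]:
        backend = "MLX" if spec.inference_framework == InferenceAsrFramework.MLX else "native whisper"
        print(f"whisper-{name}: {backend} ({spec.repo_id})")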


@@ -611,6 +611,17 @@ def convert( # noqa: C901
        ocr_options.psm = psm

    accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)

    # Auto-detect pipeline based on input file formats
    if pipeline == ProcessingPipeline.STANDARD:
        # Check if any input files are audio files by extension
        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
        for path in input_doc_paths:
            if path.suffix.lower() in audio_extensions:
                pipeline = ProcessingPipeline.ASR
                _log.info(f"Auto-detected ASR pipeline for audio file: {path}")
                break

    # pipeline_options: PaginatedPipelineOptions
    pipeline_options: PipelineOptions
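For illustration, the same extension check in isolation (a minimal sketch; is_audio_file is a hypothetical helper, not part of this diff):

    from pathlib import Path

    # Hypothetical standalone version of the CLI's audio auto-detection check.
    AUDIO_EXTENSIONS = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}

    def is_audio_file(path: Path) -> bool:
        # Case-insensitive match on the file extension, as in the CLI code above
        return path.suffix.lower() in AUDIO_EXTENSIONS

    assert is_audio_file(Path('meeting.MP3'))
    assert not is_audio_file(Path('report.pdf'))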
@@ -749,6 +760,10 @@ def convert( # noqa: C901
    elif pipeline == ProcessingPipeline.ASR:
        pipeline_options = AsrPipelineOptions(
            accelerator_options=AcceleratorOptions(
                device=device,
                num_threads=num_threads,
            ),
            # enable_remote_services=enable_remote_services,
            # artifacts_path = artifacts_path
        )


@@ -17,7 +17,41 @@ from docling.datamodel.pipeline_options_asr_model import (

_log = logging.getLogger(__name__)


def _get_whisper_tiny_model():
    """
    Get the best Whisper Tiny model for the current hardware.

    Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Tiny.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )
    else:
        return InlineAsrNativeWhisperOptions(
            repo_id="tiny",
            inference_framework=InferenceAsrFramework.WHISPER,
            verbose=True,
@@ -28,6 +62,10 @@ WHISPER_TINY = InlineAsrNativeWhisperOptions(
        max_time_chunk=30.0,
    )


# Create the model instance
WHISPER_TINY = _get_whisper_tiny_model()


def _get_whisper_small_model():
    """
    Get the best Whisper Small model for the current hardware.
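If you need to pin the native backend regardless of hardware, you can bypass the factory and construct the options directly. A sketch using only constructor arguments that appear in this diff (whether further fields are required is an assumption):

    from docling.datamodel.pipeline_options_asr_model import (
        InferenceAsrFramework,
        InlineAsrNativeWhisperOptions,
    )

    # Force native Whisper tiny even on Apple Silicon with mlx-whisper installed
    native_tiny = InlineAsrNativeWhisperOptions(
        repo_id="tiny",
        inference_framework=InferenceAsrFramework.WHISPER,
        verbose=True,
    )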
@@ -77,7 +115,41 @@ def _get_whisper_small_model():

# Create the model instance
WHISPER_SMALL = _get_whisper_small_model()


def _get_whisper_medium_model():
    """
    Get the best Whisper Medium model for the current hardware.

    Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Medium.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-medium-mlx-8bit",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )
    else:
        return InlineAsrNativeWhisperOptions(
            repo_id="medium",
            inference_framework=InferenceAsrFramework.WHISPER,
            verbose=True,
@@ -88,6 +160,10 @@ WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
        max_time_chunk=30.0,
    )


# Create the model instance
WHISPER_MEDIUM = _get_whisper_medium_model()


def _get_whisper_base_model():
    """
    Get the best Whisper Base model for the current hardware.
@@ -137,7 +213,41 @@ def _get_whisper_base_model():

# Create the model instance
WHISPER_BASE = _get_whisper_base_model()


def _get_whisper_large_model():
    """
    Get the best Whisper Large model for the current hardware.

    Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Large.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-large-mlx-8bit",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )
    else:
        return InlineAsrNativeWhisperOptions(
            repo_id="large",
            inference_framework=InferenceAsrFramework.WHISPER,
            verbose=True,
@@ -148,6 +258,10 @@ WHISPER_LARGE = InlineAsrNativeWhisperOptions(
        max_time_chunk=30.0,
    )


# Create the model instance
WHISPER_LARGE = _get_whisper_large_model()


def _get_whisper_turbo_model():
    """
    Get the best Whisper Turbo model for the current hardware.

docs/examples/mlx_whisper_example.py (new file, 129 lines)

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""
Example script demonstrating MLX Whisper integration for Apple Silicon.

This script shows how to use the MLX Whisper models for speech recognition
on Apple Silicon devices with optimized performance.
"""

import sys
from pathlib import Path

# Add the repository root to the path so we can import docling
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from docling.datamodel.asr_model_specs import (
    WHISPER_TINY,
    WHISPER_BASE,
    WHISPER_SMALL,
    WHISPER_MEDIUM,
    WHISPER_LARGE,
    WHISPER_TURBO,
)
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.pipeline.asr_pipeline import AsrPipeline
from docling.document_converter import DocumentConverter, AudioFormatOption


def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
    """
    Transcribe audio using Whisper models with automatic MLX optimization for Apple Silicon.

    Args:
        audio_file_path: Path to the audio file to transcribe
        model_size: Size of the Whisper model to use
            ("tiny", "base", "small", "medium", "large", "turbo")
            Note: MLX optimization is automatically used on Apple Silicon when available

    Returns:
        The transcribed text
    """
    # Select the appropriate Whisper model (automatically uses MLX on Apple Silicon)
    model_map = {
        "tiny": WHISPER_TINY,
        "base": WHISPER_BASE,
        "small": WHISPER_SMALL,
        "medium": WHISPER_MEDIUM,
        "large": WHISPER_LARGE,
        "turbo": WHISPER_TURBO,
    }

    if model_size not in model_map:
        raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}")

    asr_options = model_map[model_size]

    # Configure accelerator options for Apple Silicon
    accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)

    # Create pipeline options
    pipeline_options = AsrPipelineOptions(
        asr_options=asr_options,
        accelerator_options=accelerator_options,
    )

    # Create document converter with MLX Whisper configuration
    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )

    # Run transcription
    result = converter.convert(Path(audio_file_path))

    if result.status.value == "success":
        # Extract text from the document
        text_content = []
        for item in result.document.texts:
            text_content.append(item.text)
        return "\n".join(text_content)
    else:
        raise RuntimeError(f"Transcription failed: {result.status}")


def main():
    """Main function to demonstrate MLX Whisper usage."""
    if len(sys.argv) < 2:
        print("Usage: python mlx_whisper_example.py <audio_file_path> [model_size]")
        print("Model sizes: tiny, base, small, medium, large, turbo")
        print("Example: python mlx_whisper_example.py audio.wav base")
        sys.exit(1)

    audio_file_path = sys.argv[1]
    model_size = sys.argv[2] if len(sys.argv) > 2 else "base"

    if not Path(audio_file_path).exists():
        print(f"Error: Audio file '{audio_file_path}' not found.")
        sys.exit(1)

    try:
        print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...")
        print("Note: MLX optimization is automatically used on Apple Silicon when available.")
        print()

        transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size)

        print("Transcription Result:")
        print("=" * 50)
        print(transcribed_text)
        print("=" * 50)
    except ImportError as e:
        print(f"Error: {e}")
        print("Please install mlx-whisper: pip install mlx-whisper")
        print("Or install with uv: uv sync --extra asr")
        sys.exit(1)
    except Exception as e:
        print(f"Error during transcription: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()


@@ -0,0 +1,206 @@
"""
Test MLX Whisper integration for Apple Silicon ASR pipeline.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from docling.datamodel.asr_model_specs import (
WHISPER_TINY,
WHISPER_BASE,
WHISPER_SMALL,
WHISPER_MEDIUM,
WHISPER_LARGE,
WHISPER_TURBO,
)
from docling.datamodel.pipeline_options_asr_model import (
InferenceAsrFramework,
InlineAsrMlxWhisperOptions,
)
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
from docling.datamodel.pipeline_options import AsrPipelineOptions
class TestMlxWhisperIntegration:
"""Test MLX Whisper model integration."""
def test_mlx_whisper_options_creation(self):
"""Test that MLX Whisper options are created correctly."""
options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
language="en",
task="transcribe",
)
assert options.inference_framework == InferenceAsrFramework.MLX
assert options.repo_id == "mlx-community/whisper-tiny-mlx"
assert options.language == "en"
assert options.task == "transcribe"
assert options.word_timestamps is True
assert AcceleratorDevice.MPS in options.supported_devices
def test_whisper_models_auto_select_mlx(self):
"""Test that Whisper models automatically select MLX when MPS and mlx-whisper are available."""
# This test verifies that the models are correctly configured
# In a real Apple Silicon environment with mlx-whisper installed,
# these models would automatically use MLX
# Check that the models exist and have the correct structure
assert hasattr(WHISPER_TURBO, 'inference_framework')
assert hasattr(WHISPER_TURBO, 'repo_id')
assert hasattr(WHISPER_BASE, 'inference_framework')
assert hasattr(WHISPER_BASE, 'repo_id')
assert hasattr(WHISPER_SMALL, 'inference_framework')
assert hasattr(WHISPER_SMALL, 'repo_id')
@patch('builtins.__import__')
def test_mlx_whisper_model_initialization(self, mock_import):
"""Test MLX Whisper model initialization."""
# Mock the mlx_whisper import
mock_mlx_whisper = Mock()
mock_import.return_value = mock_mlx_whisper
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
inference_framework=InferenceAsrFramework.MLX,
language="en",
task="transcribe",
word_timestamps=True,
no_speech_threshold=0.6,
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
model = _MlxWhisperModel(
enabled=True,
artifacts_path=None,
accelerator_options=accelerator_options,
asr_options=asr_options,
)
assert model.enabled is True
assert model.model_path == "mlx-community/whisper-tiny-mlx"
assert model.language == "en"
assert model.task == "transcribe"
assert model.word_timestamps is True
def test_mlx_whisper_model_import_error(self):
"""Test that ImportError is raised when mlx-whisper is not available."""
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
inference_framework=InferenceAsrFramework.MLX,
language="en",
task="transcribe",
word_timestamps=True,
no_speech_threshold=0.6,
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
with pytest.raises(ImportError, match="mlx-whisper is not installed"):
_MlxWhisperModel(
enabled=True,
artifacts_path=None,
accelerator_options=accelerator_options,
asr_options=asr_options,
)
@patch('builtins.__import__')
def test_mlx_whisper_transcribe(self, mock_import):
"""Test MLX Whisper transcription method."""
# Mock the mlx_whisper module and its transcribe function
mock_mlx_whisper = Mock()
mock_import.return_value = mock_mlx_whisper
# Mock the transcribe result
mock_result = {
"segments": [
{
"start": 0.0,
"end": 2.5,
"text": "Hello world",
"words": [
{"start": 0.0, "end": 0.5, "word": "Hello"},
{"start": 0.5, "end": 1.0, "word": "world"},
]
}
]
}
mock_mlx_whisper.transcribe.return_value = mock_result
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
inference_framework=InferenceAsrFramework.MLX,
language="en",
task="transcribe",
word_timestamps=True,
no_speech_threshold=0.6,
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
model = _MlxWhisperModel(
enabled=True,
artifacts_path=None,
accelerator_options=accelerator_options,
asr_options=asr_options,
)
# Test transcription
audio_path = Path("test_audio.wav")
result = model.transcribe(audio_path)
# Verify the result
assert len(result) == 1
assert result[0].start_time == 0.0
assert result[0].end_time == 2.5
assert result[0].text == "Hello world"
assert len(result[0].words) == 2
assert result[0].words[0].text == "Hello"
assert result[0].words[1].text == "world"
# Verify mlx_whisper.transcribe was called with correct parameters
mock_mlx_whisper.transcribe.assert_called_once_with(
str(audio_path),
path_or_hf_repo="mlx-community/whisper-tiny-mlx",
language="en",
task="transcribe",
word_timestamps=True,
no_speech_threshold=0.6,
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
@patch('builtins.__import__')
def test_asr_pipeline_with_mlx_whisper(self, mock_import):
"""Test that AsrPipeline can be initialized with MLX Whisper options."""
# Mock the mlx_whisper import
mock_mlx_whisper = Mock()
mock_import.return_value = mock_mlx_whisper
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
inference_framework=InferenceAsrFramework.MLX,
language="en",
task="transcribe",
word_timestamps=True,
no_speech_threshold=0.6,
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
pipeline_options = AsrPipelineOptions(
asr_options=asr_options,
accelerator_options=accelerator_options,
)
pipeline = AsrPipeline(pipeline_options)
assert isinstance(pipeline._model, _MlxWhisperModel)
assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
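To run just these tests without remembering the file path, pytest's keyword selection can be used programmatically (a sketch; assumes pytest is installed in the environment):

    import pytest

    # Select the test class added in this commit by keyword and run it verbosely
    raise SystemExit(pytest.main(['-k', 'TestMlxWhisperIntegration', '-v']))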