diff --git a/docling/cli/main.py b/docling/cli/main.py
index 8f1b1cd6..bcae0648 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -611,6 +611,17 @@ def convert(  # noqa: C901
         ocr_options.psm = psm
 
     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
+
+    # Auto-detect pipeline based on input file formats
+    if pipeline == ProcessingPipeline.STANDARD:
+        # Check if any input files are audio/video files by extension
+        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
+        for path in input_doc_paths:
+            if path.suffix.lower() in audio_extensions:
+                pipeline = ProcessingPipeline.ASR
+                _log.info(f"Auto-detected ASR pipeline for audio file: {path}")
+                break
+
     # pipeline_options: PaginatedPipelineOptions
     pipeline_options: PipelineOptions
 
@@ -749,6 +760,10 @@ def convert(  # noqa: C901
 
     elif pipeline == ProcessingPipeline.ASR:
         pipeline_options = AsrPipelineOptions(
+            accelerator_options=AcceleratorOptions(
+                device=device,
+                num_threads=num_threads,
+            ),
             # enable_remote_services=enable_remote_services,
             # artifacts_path = artifacts_path
         )
diff --git a/docling/datamodel/asr_model_specs.py b/docling/datamodel/asr_model_specs.py
index eb0a536e..8b7f7a3c 100644
--- a/docling/datamodel/asr_model_specs.py
+++ b/docling/datamodel/asr_model_specs.py
@@ -17,16 +17,54 @@ from docling.datamodel.pipeline_options_asr_model import (
 _log = logging.getLogger(__name__)
 
-WHISPER_TINY = InlineAsrNativeWhisperOptions(
-    repo_id="tiny",
-    inference_framework=InferenceAsrFramework.WHISPER,
-    verbose=True,
-    timestamps=True,
-    word_timestamps=True,
-    temperature=0.0,
-    max_new_tokens=256,
-    max_time_chunk=30.0,
-)
+def _get_whisper_tiny_model():
+    """
+    Get the best Whisper Tiny model for the current hardware.
+
+    Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Tiny.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-tiny-mlx",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="tiny",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_TINY = _get_whisper_tiny_model()
 
 
 def _get_whisper_small_model():
     """
@@ -77,16 +115,54 @@ def _get_whisper_small_model():
 # Create the model instance
 WHISPER_SMALL = _get_whisper_small_model()
 
-WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
-    repo_id="medium",
-    inference_framework=InferenceAsrFramework.WHISPER,
-    verbose=True,
-    timestamps=True,
-    word_timestamps=True,
-    temperature=0.0,
-    max_new_tokens=256,
-    max_time_chunk=30.0,
-)
+def _get_whisper_medium_model():
+    """
+    Get the best Whisper Medium model for the current hardware.
+
+    Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Medium.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-medium-mlx-8bit",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="medium",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_MEDIUM = _get_whisper_medium_model()
 
 
 def _get_whisper_base_model():
     """
@@ -137,16 +213,54 @@ def _get_whisper_base_model():
 # Create the model instance
 WHISPER_BASE = _get_whisper_base_model()
 
-WHISPER_LARGE = InlineAsrNativeWhisperOptions(
-    repo_id="large",
-    inference_framework=InferenceAsrFramework.WHISPER,
-    verbose=True,
-    timestamps=True,
-    word_timestamps=True,
-    temperature=0.0,
-    max_new_tokens=256,
-    max_time_chunk=30.0,
-)
+def _get_whisper_large_model():
+    """
+    Get the best Whisper Large model for the current hardware.
+
+    Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
+    otherwise falls back to native Whisper Large.
+    """
+    # Check if MPS is available (Apple Silicon)
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    # Check if mlx-whisper is available
+    try:
+        import mlx_whisper  # type: ignore
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    # Use MLX Whisper if both MPS and mlx-whisper are available
+    if has_mps and has_mlx_whisper:
+        return InlineAsrMlxWhisperOptions(
+            repo_id="mlx-community/whisper-large-mlx-8bit",
+            inference_framework=InferenceAsrFramework.MLX,
+            language="en",
+            task="transcribe",
+            word_timestamps=True,
+            no_speech_threshold=0.6,
+            logprob_threshold=-1.0,
+            compression_ratio_threshold=2.4,
+        )
+    else:
+        return InlineAsrNativeWhisperOptions(
+            repo_id="large",
+            inference_framework=InferenceAsrFramework.WHISPER,
+            verbose=True,
+            timestamps=True,
+            word_timestamps=True,
+            temperature=0.0,
+            max_new_tokens=256,
+            max_time_chunk=30.0,
+        )
+
+
+# Create the model instance
+WHISPER_LARGE = _get_whisper_large_model()
 
 
 def _get_whisper_turbo_model():
     """
diff --git a/docs/examples/mlx_whisper_example.py b/docs/examples/mlx_whisper_example.py
new file mode 100644
index 00000000..106eb708
--- /dev/null
+++ b/docs/examples/mlx_whisper_example.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+"""
+Example script demonstrating MLX Whisper integration for Apple Silicon.
+
+This script shows how to use the MLX Whisper models for speech recognition
+on Apple Silicon devices with optimized performance.
+""" + +import sys +from pathlib import Path + +# Add the repository root to the path so we can import docling +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from docling.datamodel.asr_model_specs import ( + WHISPER_TINY, + WHISPER_BASE, + WHISPER_SMALL, + WHISPER_MEDIUM, + WHISPER_LARGE, + WHISPER_TURBO, +) +from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice +from docling.datamodel.base_models import InputFormat +from docling.datamodel.pipeline_options import AsrPipelineOptions +from docling.pipeline.asr_pipeline import AsrPipeline +from docling.document_converter import DocumentConverter, AudioFormatOption + + +def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"): + """ + Transcribe audio using Whisper models with automatic MLX optimization for Apple Silicon. + + Args: + audio_file_path: Path to the audio file to transcribe + model_size: Size of the Whisper model to use + ("tiny", "base", "small", "medium", "large", "turbo") + Note: MLX optimization is automatically used on Apple Silicon when available + + Returns: + The transcribed text + """ + # Select the appropriate Whisper model (automatically uses MLX on Apple Silicon) + model_map = { + "tiny": WHISPER_TINY, + "base": WHISPER_BASE, + "small": WHISPER_SMALL, + "medium": WHISPER_MEDIUM, + "large": WHISPER_LARGE, + "turbo": WHISPER_TURBO, + } + + if model_size not in model_map: + raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}") + + asr_options = model_map[model_size] + + # Configure accelerator options for Apple Silicon + accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) + + # Create pipeline options + pipeline_options = AsrPipelineOptions( + asr_options=asr_options, + accelerator_options=accelerator_options, + ) + + # Create document converter with MLX Whisper configuration + converter = DocumentConverter( + format_options={ + InputFormat.AUDIO: AudioFormatOption( + pipeline_cls=AsrPipeline, + pipeline_options=pipeline_options, + ) + } + ) + + # Run transcription + result = converter.convert(Path(audio_file_path)) + + if result.status.value == "success": + # Extract text from the document + text_content = [] + for item in result.document.texts: + text_content.append(item.text) + + return "\n".join(text_content) + else: + raise RuntimeError(f"Transcription failed: {result.status}") + + +def main(): + """Main function to demonstrate MLX Whisper usage.""" + if len(sys.argv) < 2: + print("Usage: python mlx_whisper_example.py [model_size]") + print("Model sizes: tiny, base, small, medium, large, turbo") + print("Example: python mlx_whisper_example.py audio.wav base") + sys.exit(1) + + audio_file_path = sys.argv[1] + model_size = sys.argv[2] if len(sys.argv) > 2 else "base" + + if not Path(audio_file_path).exists(): + print(f"Error: Audio file '{audio_file_path}' not found.") + sys.exit(1) + + try: + print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...") + print("Note: MLX optimization is automatically used on Apple Silicon when available.") + print() + + transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size) + + print("Transcription Result:") + print("=" * 50) + print(transcribed_text) + print("=" * 50) + + except ImportError as e: + print(f"Error: {e}") + print("Please install mlx-whisper: pip install mlx-whisper") + print("Or install with uv: uv sync --extra asr") + sys.exit(1) + except Exception as e: + print(f"Error during 
transcription: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/test_asr_mlx_whisper.py b/tests/test_asr_mlx_whisper.py new file mode 100644 index 00000000..8d860182 --- /dev/null +++ b/tests/test_asr_mlx_whisper.py @@ -0,0 +1,206 @@ +""" +Test MLX Whisper integration for Apple Silicon ASR pipeline. +""" +import pytest +from pathlib import Path +from unittest.mock import Mock, patch + +from docling.datamodel.asr_model_specs import ( + WHISPER_TINY, + WHISPER_BASE, + WHISPER_SMALL, + WHISPER_MEDIUM, + WHISPER_LARGE, + WHISPER_TURBO, +) +from docling.datamodel.pipeline_options_asr_model import ( + InferenceAsrFramework, + InlineAsrMlxWhisperOptions, +) +from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice +from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel +from docling.datamodel.pipeline_options import AsrPipelineOptions + + +class TestMlxWhisperIntegration: + """Test MLX Whisper model integration.""" + + def test_mlx_whisper_options_creation(self): + """Test that MLX Whisper options are created correctly.""" + options = InlineAsrMlxWhisperOptions( + repo_id="mlx-community/whisper-tiny-mlx", + language="en", + task="transcribe", + ) + + assert options.inference_framework == InferenceAsrFramework.MLX + assert options.repo_id == "mlx-community/whisper-tiny-mlx" + assert options.language == "en" + assert options.task == "transcribe" + assert options.word_timestamps is True + assert AcceleratorDevice.MPS in options.supported_devices + + def test_whisper_models_auto_select_mlx(self): + """Test that Whisper models automatically select MLX when MPS and mlx-whisper are available.""" + # This test verifies that the models are correctly configured + # In a real Apple Silicon environment with mlx-whisper installed, + # these models would automatically use MLX + + # Check that the models exist and have the correct structure + assert hasattr(WHISPER_TURBO, 'inference_framework') + assert hasattr(WHISPER_TURBO, 'repo_id') + + assert hasattr(WHISPER_BASE, 'inference_framework') + assert hasattr(WHISPER_BASE, 'repo_id') + + assert hasattr(WHISPER_SMALL, 'inference_framework') + assert hasattr(WHISPER_SMALL, 'repo_id') + + @patch('builtins.__import__') + def test_mlx_whisper_model_initialization(self, mock_import): + """Test MLX Whisper model initialization.""" + # Mock the mlx_whisper import + mock_mlx_whisper = Mock() + mock_import.return_value = mock_mlx_whisper + + accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) + asr_options = InlineAsrMlxWhisperOptions( + repo_id="mlx-community/whisper-tiny-mlx", + inference_framework=InferenceAsrFramework.MLX, + language="en", + task="transcribe", + word_timestamps=True, + no_speech_threshold=0.6, + logprob_threshold=-1.0, + compression_ratio_threshold=2.4, + ) + + model = _MlxWhisperModel( + enabled=True, + artifacts_path=None, + accelerator_options=accelerator_options, + asr_options=asr_options, + ) + + assert model.enabled is True + assert model.model_path == "mlx-community/whisper-tiny-mlx" + assert model.language == "en" + assert model.task == "transcribe" + assert model.word_timestamps is True + + def test_mlx_whisper_model_import_error(self): + """Test that ImportError is raised when mlx-whisper is not available.""" + accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) + asr_options = InlineAsrMlxWhisperOptions( + repo_id="mlx-community/whisper-tiny-mlx", + inference_framework=InferenceAsrFramework.MLX, + language="en", + 
task="transcribe", + word_timestamps=True, + no_speech_threshold=0.6, + logprob_threshold=-1.0, + compression_ratio_threshold=2.4, + ) + + with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")): + with pytest.raises(ImportError, match="mlx-whisper is not installed"): + _MlxWhisperModel( + enabled=True, + artifacts_path=None, + accelerator_options=accelerator_options, + asr_options=asr_options, + ) + + @patch('builtins.__import__') + def test_mlx_whisper_transcribe(self, mock_import): + """Test MLX Whisper transcription method.""" + # Mock the mlx_whisper module and its transcribe function + mock_mlx_whisper = Mock() + mock_import.return_value = mock_mlx_whisper + + # Mock the transcribe result + mock_result = { + "segments": [ + { + "start": 0.0, + "end": 2.5, + "text": "Hello world", + "words": [ + {"start": 0.0, "end": 0.5, "word": "Hello"}, + {"start": 0.5, "end": 1.0, "word": "world"}, + ] + } + ] + } + mock_mlx_whisper.transcribe.return_value = mock_result + + accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) + asr_options = InlineAsrMlxWhisperOptions( + repo_id="mlx-community/whisper-tiny-mlx", + inference_framework=InferenceAsrFramework.MLX, + language="en", + task="transcribe", + word_timestamps=True, + no_speech_threshold=0.6, + logprob_threshold=-1.0, + compression_ratio_threshold=2.4, + ) + + model = _MlxWhisperModel( + enabled=True, + artifacts_path=None, + accelerator_options=accelerator_options, + asr_options=asr_options, + ) + + # Test transcription + audio_path = Path("test_audio.wav") + result = model.transcribe(audio_path) + + # Verify the result + assert len(result) == 1 + assert result[0].start_time == 0.0 + assert result[0].end_time == 2.5 + assert result[0].text == "Hello world" + assert len(result[0].words) == 2 + assert result[0].words[0].text == "Hello" + assert result[0].words[1].text == "world" + + # Verify mlx_whisper.transcribe was called with correct parameters + mock_mlx_whisper.transcribe.assert_called_once_with( + str(audio_path), + path_or_hf_repo="mlx-community/whisper-tiny-mlx", + language="en", + task="transcribe", + word_timestamps=True, + no_speech_threshold=0.6, + logprob_threshold=-1.0, + compression_ratio_threshold=2.4, + ) + + @patch('builtins.__import__') + def test_asr_pipeline_with_mlx_whisper(self, mock_import): + """Test that AsrPipeline can be initialized with MLX Whisper options.""" + # Mock the mlx_whisper import + mock_mlx_whisper = Mock() + mock_import.return_value = mock_mlx_whisper + + accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) + asr_options = InlineAsrMlxWhisperOptions( + repo_id="mlx-community/whisper-tiny-mlx", + inference_framework=InferenceAsrFramework.MLX, + language="en", + task="transcribe", + word_timestamps=True, + no_speech_threshold=0.6, + logprob_threshold=-1.0, + compression_ratio_threshold=2.4, + ) + pipeline_options = AsrPipelineOptions( + asr_options=asr_options, + accelerator_options=accelerator_options, + ) + + pipeline = AsrPipeline(pipeline_options) + assert isinstance(pipeline._model, _MlxWhisperModel) + assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"