diff --git a/docling/cli/main.py b/docling/cli/main.py
index bcae0648..8bfe8989 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -611,17 +611,27 @@ def convert(  # noqa: C901
         ocr_options.psm = psm

     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
-
+
     # Auto-detect pipeline based on input file formats
     if pipeline == ProcessingPipeline.STANDARD:
         # Check if any input files are audio files by extension
-        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
+        audio_extensions = {
+            ".mp3",
+            ".wav",
+            ".m4a",
+            ".aac",
+            ".ogg",
+            ".flac",
+            ".mp4",
+            ".avi",
+            ".mov",
+        }
         for path in input_doc_paths:
             if path.suffix.lower() in audio_extensions:
                 pipeline = ProcessingPipeline.ASR
                 _log.info(f"Auto-detected ASR pipeline for audio file: {path}")
                 break
-
+
     # pipeline_options: PaginatedPipelineOptions
     pipeline_options: PipelineOptions
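The detection above is purely extension-based, and a single matching suffix switches the whole invocation to the ASR pipeline (the loop breaks on the first hit), so mixed batches of audio and documents are all routed through ASR. A minimal sketch of that rule, with hypothetical file names:

    from pathlib import Path

    audio_extensions = {".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac", ".mp4", ".avi", ".mov"}

    def should_use_asr(input_doc_paths: list[Path]) -> bool:
        # Mirrors the CLI hunk: one audio suffix flips the pipeline for the whole batch.
        return any(p.suffix.lower() in audio_extensions for p in input_doc_paths)

    assert should_use_asr([Path("call.mp3"), Path("slides.pdf")])
    assert not should_use_asr([Path("slides.pdf")])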
""" # Check if MPS is available (Apple Silicon) try: import torch + has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() except ImportError: has_mps = False - + # Check if mlx-whisper is available try: import mlx_whisper # type: ignore + has_mlx_whisper = True except ImportError: has_mlx_whisper = False - + # Use MLX Whisper if both MPS and mlx-whisper are available if has_mps and has_mlx_whisper: return InlineAsrMlxWhisperOptions( @@ -164,27 +173,30 @@ def _get_whisper_medium_model(): # Create the model instance WHISPER_MEDIUM = _get_whisper_medium_model() + def _get_whisper_base_model(): """ Get the best Whisper Base model for the current hardware. - + Automatically selects MLX Whisper Base for Apple Silicon (MPS) if available, otherwise falls back to native Whisper Base. """ # Check if MPS is available (Apple Silicon) try: import torch + has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() except ImportError: has_mps = False - + # Check if mlx-whisper is available try: import mlx_whisper # type: ignore + has_mlx_whisper = True except ImportError: has_mlx_whisper = False - + # Use MLX Whisper if both MPS and mlx-whisper are available if has_mps and has_mlx_whisper: return InlineAsrMlxWhisperOptions( @@ -213,27 +225,30 @@ def _get_whisper_base_model(): # Create the model instance WHISPER_BASE = _get_whisper_base_model() + def _get_whisper_large_model(): """ Get the best Whisper Large model for the current hardware. - + Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available, otherwise falls back to native Whisper Large. """ # Check if MPS is available (Apple Silicon) try: import torch + has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() except ImportError: has_mps = False - + # Check if mlx-whisper is available try: import mlx_whisper # type: ignore + has_mlx_whisper = True except ImportError: has_mlx_whisper = False - + # Use MLX Whisper if both MPS and mlx-whisper are available if has_mps and has_mlx_whisper: return InlineAsrMlxWhisperOptions( @@ -262,27 +277,30 @@ def _get_whisper_large_model(): # Create the model instance WHISPER_LARGE = _get_whisper_large_model() + def _get_whisper_turbo_model(): """ Get the best Whisper Turbo model for the current hardware. - + Automatically selects MLX Whisper Turbo for Apple Silicon (MPS) if available, otherwise falls back to native Whisper Turbo. """ # Check if MPS is available (Apple Silicon) try: import torch + has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() except ImportError: has_mps = False - + # Check if mlx-whisper is available try: import mlx_whisper # type: ignore + has_mlx_whisper = True except ImportError: has_mlx_whisper = False - + # Use MLX Whisper if both MPS and mlx-whisper are available if has_mps and has_mlx_whisper: return InlineAsrMlxWhisperOptions( diff --git a/docling/datamodel/pipeline_options_asr_model.py b/docling/datamodel/pipeline_options_asr_model.py index 1fec9dfe..24b161ad 100644 --- a/docling/datamodel/pipeline_options_asr_model.py +++ b/docling/datamodel/pipeline_options_asr_model.py @@ -60,9 +60,10 @@ class InlineAsrNativeWhisperOptions(InlineAsrOptions): class InlineAsrMlxWhisperOptions(InlineAsrOptions): """ MLX Whisper options for Apple Silicon optimization. - + Uses mlx-whisper library for efficient inference on Apple Silicon devices. 
""" + inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX language: str = "en" diff --git a/docling/pipeline/asr_pipeline.py b/docling/pipeline/asr_pipeline.py index 7383e00d..92b298c9 100644 --- a/docling/pipeline/asr_pipeline.py +++ b/docling/pipeline/asr_pipeline.py @@ -4,7 +4,7 @@ import re import tempfile from io import BytesIO from pathlib import Path -from typing import List, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional, Union, cast from docling_core.types.doc import DoclingDocument, DocumentOrigin @@ -32,8 +32,8 @@ from docling.datamodel.pipeline_options import ( AsrPipelineOptions, ) from docling.datamodel.pipeline_options_asr_model import ( - InlineAsrNativeWhisperOptions, InlineAsrMlxWhisperOptions, + InlineAsrNativeWhisperOptions, # AsrResponseFormat, InlineAsrOptions, ) @@ -263,7 +263,7 @@ class _MlxWhisperModel: self.model_name = asr_options.repo_id _log.info(f"loading _MlxWhisperModel({self.model_name})") - + # MLX Whisper models are loaded differently - they use HuggingFace repos self.model_path = self.model_name @@ -308,10 +308,10 @@ class _MlxWhisperModel: def transcribe(self, fpath: Path) -> list[_ConversationItem]: """ Transcribe audio using MLX Whisper. - + Args: fpath: Path to audio file - + Returns: List of conversation items with timestamps """ @@ -327,16 +327,16 @@ class _MlxWhisperModel: ) convo: list[_ConversationItem] = [] - + # MLX Whisper returns segments similar to native Whisper for segment in result.get("segments", []): item = _ConversationItem( start_time=segment.get("start"), end_time=segment.get("end"), text=segment.get("text", "").strip(), - words=[] + words=[], ) - + # Add word-level timestamps if available if self.word_timestamps and "words" in segment: item.words = [] @@ -359,26 +359,27 @@ class AsrPipeline(BasePipeline): self.keep_backend = True self.pipeline_options: AsrPipelineOptions = pipeline_options + self._model: Union[_NativeWhisperModel, _MlxWhisperModel] if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions): - asr_options: InlineAsrNativeWhisperOptions = ( + native_asr_options: InlineAsrNativeWhisperOptions = ( self.pipeline_options.asr_options ) self._model = _NativeWhisperModel( enabled=True, # must be always enabled for this pipeline to make sense. artifacts_path=self.artifacts_path, accelerator_options=pipeline_options.accelerator_options, - asr_options=asr_options, + asr_options=native_asr_options, ) elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions): - asr_options: InlineAsrMlxWhisperOptions = ( + mlx_asr_options: InlineAsrMlxWhisperOptions = ( self.pipeline_options.asr_options ) self._model = _MlxWhisperModel( enabled=True, # must be always enabled for this pipeline to make sense. artifacts_path=self.artifacts_path, accelerator_options=pipeline_options.accelerator_options, - asr_options=asr_options, + asr_options=mlx_asr_options, ) else: _log.error(f"No model support for {self.pipeline_options.asr_options}") diff --git a/docs/examples/minimal_asr_pipeline.py b/docs/examples/minimal_asr_pipeline.py index 5e08d0c2..5af2bb83 100644 --- a/docs/examples/minimal_asr_pipeline.py +++ b/docs/examples/minimal_asr_pipeline.py @@ -43,7 +43,7 @@ def get_asr_converter(): implementation for your hardware: - MLX Whisper Turbo for Apple Silicon (M1/M2/M3) with mlx-whisper installed - Native Whisper Turbo as fallback - + You can swap in another model spec from `docling.datamodel.asr_model_specs` to experiment with different model sizes. 
""" diff --git a/docs/examples/mlx_whisper_example.py b/docs/examples/mlx_whisper_example.py index 106eb708..71aad6f9 100644 --- a/docs/examples/mlx_whisper_example.py +++ b/docs/examples/mlx_whisper_example.py @@ -12,31 +12,31 @@ from pathlib import Path # Add the repository root to the path so we can import docling sys.path.insert(0, str(Path(__file__).parent.parent.parent)) +from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.asr_model_specs import ( - WHISPER_TINY, WHISPER_BASE, - WHISPER_SMALL, - WHISPER_MEDIUM, WHISPER_LARGE, + WHISPER_MEDIUM, + WHISPER_SMALL, + WHISPER_TINY, WHISPER_TURBO, ) -from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice from docling.datamodel.base_models import InputFormat from docling.datamodel.pipeline_options import AsrPipelineOptions +from docling.document_converter import AudioFormatOption, DocumentConverter from docling.pipeline.asr_pipeline import AsrPipeline -from docling.document_converter import DocumentConverter, AudioFormatOption def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"): """ Transcribe audio using Whisper models with automatic MLX optimization for Apple Silicon. - + Args: audio_file_path: Path to the audio file to transcribe model_size: Size of the Whisper model to use ("tiny", "base", "small", "medium", "large", "turbo") Note: MLX optimization is automatically used on Apple Silicon when available - + Returns: The transcribed text """ @@ -49,21 +49,23 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b "large": WHISPER_LARGE, "turbo": WHISPER_TURBO, } - + if model_size not in model_map: - raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}") - + raise ValueError( + f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}" + ) + asr_options = model_map[model_size] - + # Configure accelerator options for Apple Silicon accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) - + # Create pipeline options pipeline_options = AsrPipelineOptions( asr_options=asr_options, accelerator_options=accelerator_options, ) - + # Create document converter with MLX Whisper configuration converter = DocumentConverter( format_options={ @@ -73,16 +75,16 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b ) } ) - + # Run transcription result = converter.convert(Path(audio_file_path)) - + if result.status.value == "success": # Extract text from the document text_content = [] for item in result.document.texts: text_content.append(item.text) - + return "\n".join(text_content) else: raise RuntimeError(f"Transcription failed: {result.status}") @@ -95,26 +97,30 @@ def main(): print("Model sizes: tiny, base, small, medium, large, turbo") print("Example: python mlx_whisper_example.py audio.wav base") sys.exit(1) - + audio_file_path = sys.argv[1] model_size = sys.argv[2] if len(sys.argv) > 2 else "base" - + if not Path(audio_file_path).exists(): print(f"Error: Audio file '{audio_file_path}' not found.") sys.exit(1) - + try: print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...") - print("Note: MLX optimization is automatically used on Apple Silicon when available.") + print( + "Note: MLX optimization is automatically used on Apple Silicon when available." 
diff --git a/tests/test_asr_mlx_whisper.py b/tests/test_asr_mlx_whisper.py
index 8d860182..b6cb2129 100644
--- a/tests/test_asr_mlx_whisper.py
+++ b/tests/test_asr_mlx_whisper.py
@@ -1,25 +1,27 @@
 """
 Test MLX Whisper integration for Apple Silicon ASR pipeline.
 """
-import pytest
+
 from pathlib import Path
 from unittest.mock import Mock, patch

+import pytest
+
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
     WHISPER_BASE,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
 )
+from docling.datamodel.pipeline_options import AsrPipelineOptions
 from docling.datamodel.pipeline_options_asr_model import (
     InferenceAsrFramework,
     InlineAsrMlxWhisperOptions,
 )
-from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
 from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
-from docling.datamodel.pipeline_options import AsrPipelineOptions


 class TestMlxWhisperIntegration:
@@ -32,7 +34,7 @@ class TestMlxWhisperIntegration:
             language="en",
             task="transcribe",
         )
-
+
         assert options.inference_framework == InferenceAsrFramework.MLX
         assert options.repo_id == "mlx-community/whisper-tiny-mlx"
         assert options.language == "en"
@@ -45,24 +47,24 @@ class TestMlxWhisperIntegration:
         # This test verifies that the models are correctly configured
         # In a real Apple Silicon environment with mlx-whisper installed,
         # these models would automatically use MLX
-
-        # Check that the models exist and have the correct structure
-        assert hasattr(WHISPER_TURBO, 'inference_framework')
-        assert hasattr(WHISPER_TURBO, 'repo_id')
-
-        assert hasattr(WHISPER_BASE, 'inference_framework')
-        assert hasattr(WHISPER_BASE, 'repo_id')
-
-        assert hasattr(WHISPER_SMALL, 'inference_framework')
-        assert hasattr(WHISPER_SMALL, 'repo_id')
-    @patch('builtins.__import__')
+        # Check that the models exist and have the correct structure
+        assert hasattr(WHISPER_TURBO, "inference_framework")
+        assert hasattr(WHISPER_TURBO, "repo_id")
+
+        assert hasattr(WHISPER_BASE, "inference_framework")
+        assert hasattr(WHISPER_BASE, "repo_id")
+
+        assert hasattr(WHISPER_SMALL, "inference_framework")
+        assert hasattr(WHISPER_SMALL, "repo_id")
+
+    @patch("builtins.__import__")
     def test_mlx_whisper_model_initialization(self, mock_import):
         """Test MLX Whisper model initialization."""
         # Mock the mlx_whisper import
         mock_mlx_whisper = Mock()
         mock_import.return_value = mock_mlx_whisper
-
+
         accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
         asr_options = InlineAsrMlxWhisperOptions(
             repo_id="mlx-community/whisper-tiny-mlx",
@@ -74,14 +76,14 @@ class TestMlxWhisperIntegration:
             logprob_threshold=-1.0,
             compression_ratio_threshold=2.4,
         )
-
+
         model = _MlxWhisperModel(
             enabled=True,
             artifacts_path=None,
             accelerator_options=accelerator_options,
             asr_options=asr_options,
         )
-
+
         assert model.enabled is True
         assert model.model_path == "mlx-community/whisper-tiny-mlx"
         assert model.language == "en"
@@ -101,8 +103,11 @@ class TestMlxWhisperIntegration:
             logprob_threshold=-1.0,
             compression_ratio_threshold=2.4,
         )
-
-        with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
+
+        with patch(
+            "builtins.__import__",
+            side_effect=ImportError("No module named 'mlx_whisper'"),
+        ):
             with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                 _MlxWhisperModel(
                     enabled=True,
@@ -111,13 +116,13 @@ class TestMlxWhisperIntegration:
                     asr_options=asr_options,
                 )

-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_transcribe(self, mock_import):
         """Test MLX Whisper transcription method."""
         # Mock the mlx_whisper module and its transcribe function
         mock_mlx_whisper = Mock()
         mock_import.return_value = mock_mlx_whisper
-
+
         # Mock the transcribe result
         mock_result = {
             "segments": [
@@ -128,12 +133,12 @@ class TestMlxWhisperIntegration:
                     "words": [
                         {"start": 0.0, "end": 0.5, "word": "Hello"},
                         {"start": 0.5, "end": 1.0, "word": "world"},
-                    ]
+                    ],
                 }
             ]
         }
         mock_mlx_whisper.transcribe.return_value = mock_result
-
+
         accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
         asr_options = InlineAsrMlxWhisperOptions(
             repo_id="mlx-community/whisper-tiny-mlx",
@@ -145,18 +150,18 @@ class TestMlxWhisperIntegration:
             logprob_threshold=-1.0,
             compression_ratio_threshold=2.4,
         )
-
+
         model = _MlxWhisperModel(
             enabled=True,
             artifacts_path=None,
             accelerator_options=accelerator_options,
             asr_options=asr_options,
         )
-
+
         # Test transcription
         audio_path = Path("test_audio.wav")
         result = model.transcribe(audio_path)
-
+
         # Verify the result
         assert len(result) == 1
         assert result[0].start_time == 0.0
@@ -165,7 +170,7 @@ class TestMlxWhisperIntegration:
         assert len(result[0].words) == 2
         assert result[0].words[0].text == "Hello"
         assert result[0].words[1].text == "world"
-
+
         # Verify mlx_whisper.transcribe was called with correct parameters
         mock_mlx_whisper.transcribe.assert_called_once_with(
             str(audio_path),
@@ -178,13 +183,13 @@ class TestMlxWhisperIntegration:
             compression_ratio_threshold=2.4,
         )

-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_asr_pipeline_with_mlx_whisper(self, mock_import):
         """Test that AsrPipeline can be initialized with MLX Whisper options."""
         # Mock the mlx_whisper import
         mock_mlx_whisper = Mock()
         mock_import.return_value = mock_mlx_whisper
-
+
         accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
         asr_options = InlineAsrMlxWhisperOptions(
             repo_id="mlx-community/whisper-tiny-mlx",
@@ -200,7 +205,7 @@ class TestMlxWhisperIntegration:
             asr_options=asr_options,
             accelerator_options=accelerator_options,
         )
-
+
         pipeline = AsrPipeline(pipeline_options)
         assert isinstance(pipeline._model, _MlxWhisperModel)
         assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
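Because each test patches `builtins.__import__`, the suite runs without Apple Silicon hardware or mlx-whisper installed; `pytest tests/test_asr_mlx_whisper.py` (assuming the repository's usual pytest setup) exercises it anywhere. The mocking pattern in isolation, for reference:

    from unittest.mock import Mock, patch

    # Intercept the import machinery so `import mlx_whisper` inside
    # _MlxWhisperModel resolves to a Mock instead of the real package.
    # Note the patch is broad: it intercepts every import in the block.
    with patch("builtins.__import__", return_value=Mock()):
        import mlx_whisper  # binds to the Mock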