fix pre-commit checks and add proper type safety

Ken Steele
2025-10-02 04:53:49 -07:00
parent 94803317a3
commit 21905e8ace
7 changed files with 135 additions and 94 deletions

View File

@@ -611,17 +611,27 @@ def convert( # noqa: C901
     ocr_options.psm = psm

     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)

     # Auto-detect pipeline based on input file formats
     if pipeline == ProcessingPipeline.STANDARD:
         # Check if any input files are audio files by extension
-        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
+        audio_extensions = {
+            ".mp3",
+            ".wav",
+            ".m4a",
+            ".aac",
+            ".ogg",
+            ".flac",
+            ".mp4",
+            ".avi",
+            ".mov",
+        }
         for path in input_doc_paths:
             if path.suffix.lower() in audio_extensions:
                 pipeline = ProcessingPipeline.ASR
                 _log.info(f"Auto-detected ASR pipeline for audio file: {path}")
                 break

     # pipeline_options: PaginatedPipelineOptions
     pipeline_options: PipelineOptions
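
The new auto-detection keys off file suffixes alone. A minimal standalone sketch of that behavior, using only the extension set and case-insensitive suffix check from the hunk above (`detect_asr_inputs` is a hypothetical name, not part of the commit):

from pathlib import Path

# Extension set copied from the hunk above; any match routes input to ASR.
AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac", ".mp4", ".avi", ".mov"}


def detect_asr_inputs(paths: list[Path]) -> bool:
    # suffix.lower() keeps the check case-insensitive, as in the diff.
    return any(p.suffix.lower() in AUDIO_EXTENSIONS for p in paths)


print(detect_asr_inputs([Path("talk.MP3"), Path("paper.pdf")]))  # -> True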

View File

@@ -10,34 +10,37 @@ from docling.datamodel.pipeline_options_asr_model import (
     # AsrResponseFormat,
     # ApiAsrOptions,
     InferenceAsrFramework,
-    InlineAsrNativeWhisperOptions,
     InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
     TransformersModelType,
 )

 _log = logging.getLogger(__name__)

+
 def _get_whisper_tiny_model():
     """
     Get the best Whisper Tiny model for the current hardware.

     Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
     otherwise falls back to native Whisper Tiny.
     """
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False

     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False

     # Use MLX Whisper if both MPS and mlx-whisper are available
     if has_mps and has_mlx_whisper:
         return InlineAsrMlxWhisperOptions(
@@ -66,27 +69,30 @@ def _get_whisper_tiny_model():
 # Create the model instance
 WHISPER_TINY = _get_whisper_tiny_model()

+
 def _get_whisper_small_model():
     """
     Get the best Whisper Small model for the current hardware.

     Automatically selects MLX Whisper Small for Apple Silicon (MPS) if available,
     otherwise falls back to native Whisper Small.
     """
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False

     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False

     # Use MLX Whisper if both MPS and mlx-whisper are available
     if has_mps and has_mlx_whisper:
         return InlineAsrMlxWhisperOptions(
@@ -115,27 +121,30 @@ def _get_whisper_small_model():
 # Create the model instance
 WHISPER_SMALL = _get_whisper_small_model()

+
 def _get_whisper_medium_model():
     """
     Get the best Whisper Medium model for the current hardware.

     Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
     otherwise falls back to native Whisper Medium.
     """
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False

     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False

     # Use MLX Whisper if both MPS and mlx-whisper are available
     if has_mps and has_mlx_whisper:
         return InlineAsrMlxWhisperOptions(
@@ -164,27 +173,30 @@ def _get_whisper_medium_model():
 # Create the model instance
 WHISPER_MEDIUM = _get_whisper_medium_model()

+
 def _get_whisper_base_model():
     """
     Get the best Whisper Base model for the current hardware.

     Automatically selects MLX Whisper Base for Apple Silicon (MPS) if available,
     otherwise falls back to native Whisper Base.
     """
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False

     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False

     # Use MLX Whisper if both MPS and mlx-whisper are available
     if has_mps and has_mlx_whisper:
         return InlineAsrMlxWhisperOptions(
@@ -213,27 +225,30 @@ def _get_whisper_base_model():
 # Create the model instance
 WHISPER_BASE = _get_whisper_base_model()

+
 def _get_whisper_large_model():
     """
     Get the best Whisper Large model for the current hardware.

     Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
     otherwise falls back to native Whisper Large.
     """
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False

     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False

     # Use MLX Whisper if both MPS and mlx-whisper are available
     if has_mps and has_mlx_whisper:
         return InlineAsrMlxWhisperOptions(
@@ -262,27 +277,30 @@ def _get_whisper_large_model():
 # Create the model instance
 WHISPER_LARGE = _get_whisper_large_model()

+
 def _get_whisper_turbo_model():
     """
     Get the best Whisper Turbo model for the current hardware.

     Automatically selects MLX Whisper Turbo for Apple Silicon (MPS) if available,
     otherwise falls back to native Whisper Turbo.
     """
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False

     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False

     # Use MLX Whisper if both MPS and mlx-whisper are available
     if has_mps and has_mlx_whisper:
         return InlineAsrMlxWhisperOptions(
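
All six `_get_whisper_*_model` helpers repeat the same two probes. A condensed sketch of that shared pattern, lifted from the hunks above (`_mlx_runtime_available` is a hypothetical consolidation, not something this commit adds):

def _mlx_runtime_available() -> bool:
    """Return True when Apple Silicon MPS and mlx-whisper are both usable."""
    try:
        import torch

        # Both checks matter: MPS must be compiled into this torch build
        # and actually available on the running machine.
        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    try:
        import mlx_whisper  # type: ignore  # noqa: F401

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    return has_mps and has_mlx_whisper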

View File

@@ -60,9 +60,10 @@ class InlineAsrNativeWhisperOptions(InlineAsrOptions):
 class InlineAsrMlxWhisperOptions(InlineAsrOptions):
     """
     MLX Whisper options for Apple Silicon optimization.

     Uses mlx-whisper library for efficient inference on Apple Silicon devices.
     """
+
     inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
     language: str = "en"
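
Constructing the options directly, using only the fields exercised by the tests later in this commit (`repo_id`, `language`, `task`; the framework default comes from the class above):

from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrMlxWhisperOptions,
)

# repo_id/language/task mirror the values used in this commit's tests.
options = InlineAsrMlxWhisperOptions(
    repo_id="mlx-community/whisper-tiny-mlx",
    language="en",
    task="transcribe",
)
assert options.inference_framework == InferenceAsrFramework.MLX  # class default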

View File

@@ -4,7 +4,7 @@ import re
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union, cast

 from docling_core.types.doc import DoclingDocument, DocumentOrigin
@@ -32,8 +32,8 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
-    InlineAsrNativeWhisperOptions,
     InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
 )
@@ -263,7 +263,7 @@ class _MlxWhisperModel:
         self.model_name = asr_options.repo_id

         _log.info(f"loading _MlxWhisperModel({self.model_name})")

         # MLX Whisper models are loaded differently - they use HuggingFace repos
         self.model_path = self.model_name
@@ -308,10 +308,10 @@ class _MlxWhisperModel:
     def transcribe(self, fpath: Path) -> list[_ConversationItem]:
         """
         Transcribe audio using MLX Whisper.

         Args:
             fpath: Path to audio file

         Returns:
             List of conversation items with timestamps
         """
@@ -327,16 +327,16 @@ class _MlxWhisperModel:
         )

         convo: list[_ConversationItem] = []

         # MLX Whisper returns segments similar to native Whisper
         for segment in result.get("segments", []):
             item = _ConversationItem(
                 start_time=segment.get("start"),
                 end_time=segment.get("end"),
                 text=segment.get("text", "").strip(),
-                words=[]
+                words=[],
             )

             # Add word-level timestamps if available
             if self.word_timestamps and "words" in segment:
                 item.words = []
@@ -359,26 +359,27 @@ class AsrPipeline(BasePipeline):
         self.keep_backend = True

         self.pipeline_options: AsrPipelineOptions = pipeline_options

+        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]
         if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
-            asr_options: InlineAsrNativeWhisperOptions = (
+            native_asr_options: InlineAsrNativeWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=native_asr_options,
             )
         elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
-            asr_options: InlineAsrMlxWhisperOptions = (
+            mlx_asr_options: InlineAsrMlxWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _MlxWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=mlx_asr_options,
             )
         else:
             _log.error(f"No model support for {self.pipeline_options.asr_options}")
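
The added `self._model: Union[...]` declaration is the type-safety piece of this commit: declaring the union up front lets mypy accept a different concrete type in each branch, which also motivates the `native_asr_options`/`mlx_asr_options` renames. A reduced, docling-free sketch of the same pattern (class names here are stand-ins):

from typing import Union


class _NativeModel:
    pass


class _MlxModel:
    pass


class Pipeline:
    def __init__(self, use_mlx: bool) -> None:
        # Declare the union before branching so each branch may assign
        # its own concrete type without mypy pinning the first one.
        self._model: Union[_NativeModel, _MlxModel]
        if use_mlx:
            self._model = _MlxModel()
        else:
            self._model = _NativeModel()


pipeline = Pipeline(use_mlx=True)
print(type(pipeline._model).__name__)  # -> _MlxModel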

View File

@@ -43,7 +43,7 @@ def get_asr_converter():
     implementation for your hardware:
     - MLX Whisper Turbo for Apple Silicon (M1/M2/M3) with mlx-whisper installed
     - Native Whisper Turbo as fallback

     You can swap in another model spec from `docling.datamodel.asr_model_specs`
     to experiment with different model sizes.
     """

View File

@@ -12,31 +12,31 @@ from pathlib import Path
 # Add the repository root to the path so we can import docling
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))

+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
     WHISPER_BASE,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
 )
-from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_options import AsrPipelineOptions
+from docling.document_converter import AudioFormatOption, DocumentConverter
 from docling.pipeline.asr_pipeline import AsrPipeline
-from docling.document_converter import DocumentConverter, AudioFormatOption


 def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
     """
     Transcribe audio using Whisper models with automatic MLX optimization for Apple Silicon.

     Args:
         audio_file_path: Path to the audio file to transcribe
         model_size: Size of the Whisper model to use
                     ("tiny", "base", "small", "medium", "large", "turbo")

     Note: MLX optimization is automatically used on Apple Silicon when available

     Returns:
         The transcribed text
     """
@@ -49,21 +49,23 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b
"large": WHISPER_LARGE, "large": WHISPER_LARGE,
"turbo": WHISPER_TURBO, "turbo": WHISPER_TURBO,
} }
if model_size not in model_map: if model_size not in model_map:
raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}") raise ValueError(
f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}"
)
asr_options = model_map[model_size] asr_options = model_map[model_size]
# Configure accelerator options for Apple Silicon # Configure accelerator options for Apple Silicon
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
# Create pipeline options # Create pipeline options
pipeline_options = AsrPipelineOptions( pipeline_options = AsrPipelineOptions(
asr_options=asr_options, asr_options=asr_options,
accelerator_options=accelerator_options, accelerator_options=accelerator_options,
) )
# Create document converter with MLX Whisper configuration # Create document converter with MLX Whisper configuration
converter = DocumentConverter( converter = DocumentConverter(
format_options={ format_options={
@@ -73,16 +75,16 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b
             )
         }
     )

     # Run transcription
     result = converter.convert(Path(audio_file_path))

     if result.status.value == "success":
         # Extract text from the document
         text_content = []
         for item in result.document.texts:
             text_content.append(item.text)
         return "\n".join(text_content)
     else:
         raise RuntimeError(f"Transcription failed: {result.status}")
@@ -95,26 +97,30 @@ def main():
print("Model sizes: tiny, base, small, medium, large, turbo") print("Model sizes: tiny, base, small, medium, large, turbo")
print("Example: python mlx_whisper_example.py audio.wav base") print("Example: python mlx_whisper_example.py audio.wav base")
sys.exit(1) sys.exit(1)
audio_file_path = sys.argv[1] audio_file_path = sys.argv[1]
model_size = sys.argv[2] if len(sys.argv) > 2 else "base" model_size = sys.argv[2] if len(sys.argv) > 2 else "base"
if not Path(audio_file_path).exists(): if not Path(audio_file_path).exists():
print(f"Error: Audio file '{audio_file_path}' not found.") print(f"Error: Audio file '{audio_file_path}' not found.")
sys.exit(1) sys.exit(1)
try: try:
print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...") print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...")
print("Note: MLX optimization is automatically used on Apple Silicon when available.") print(
"Note: MLX optimization is automatically used on Apple Silicon when available."
)
print() print()
transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size) transcribed_text = transcribe_audio_with_mlx_whisper(
audio_file_path, model_size
)
print("Transcription Result:") print("Transcription Result:")
print("=" * 50) print("=" * 50)
print(transcribed_text) print(transcribed_text)
print("=" * 50) print("=" * 50)
except ImportError as e: except ImportError as e:
print(f"Error: {e}") print(f"Error: {e}")
print("Please install mlx-whisper: pip install mlx-whisper") print("Please install mlx-whisper: pip install mlx-whisper")

View File

@@ -1,25 +1,27 @@
""" """
Test MLX Whisper integration for Apple Silicon ASR pipeline. Test MLX Whisper integration for Apple Silicon ASR pipeline.
""" """
import pytest
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import ( from docling.datamodel.asr_model_specs import (
WHISPER_TINY,
WHISPER_BASE, WHISPER_BASE,
WHISPER_SMALL,
WHISPER_MEDIUM,
WHISPER_LARGE, WHISPER_LARGE,
WHISPER_MEDIUM,
WHISPER_SMALL,
WHISPER_TINY,
WHISPER_TURBO, WHISPER_TURBO,
) )
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import ( from docling.datamodel.pipeline_options_asr_model import (
InferenceAsrFramework, InferenceAsrFramework,
InlineAsrMlxWhisperOptions, InlineAsrMlxWhisperOptions,
) )
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
from docling.datamodel.pipeline_options import AsrPipelineOptions
class TestMlxWhisperIntegration: class TestMlxWhisperIntegration:
@@ -32,7 +34,7 @@ class TestMlxWhisperIntegration:
language="en", language="en",
task="transcribe", task="transcribe",
) )
assert options.inference_framework == InferenceAsrFramework.MLX assert options.inference_framework == InferenceAsrFramework.MLX
assert options.repo_id == "mlx-community/whisper-tiny-mlx" assert options.repo_id == "mlx-community/whisper-tiny-mlx"
assert options.language == "en" assert options.language == "en"
@@ -45,24 +47,24 @@ class TestMlxWhisperIntegration:
         # This test verifies that the models are correctly configured
         # In a real Apple Silicon environment with mlx-whisper installed,
         # these models would automatically use MLX

         # Check that the models exist and have the correct structure
-        assert hasattr(WHISPER_TURBO, 'inference_framework')
-        assert hasattr(WHISPER_TURBO, 'repo_id')
-        assert hasattr(WHISPER_BASE, 'inference_framework')
-        assert hasattr(WHISPER_BASE, 'repo_id')
-        assert hasattr(WHISPER_SMALL, 'inference_framework')
-        assert hasattr(WHISPER_SMALL, 'repo_id')
+        assert hasattr(WHISPER_TURBO, "inference_framework")
+        assert hasattr(WHISPER_TURBO, "repo_id")
+        assert hasattr(WHISPER_BASE, "inference_framework")
+        assert hasattr(WHISPER_BASE, "repo_id")
+        assert hasattr(WHISPER_SMALL, "inference_framework")
+        assert hasattr(WHISPER_SMALL, "repo_id")

-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_model_initialization(self, mock_import):
         """Test MLX Whisper model initialization."""
         # Mock the mlx_whisper import
         mock_mlx_whisper = Mock()
         mock_import.return_value = mock_mlx_whisper

         accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
         asr_options = InlineAsrMlxWhisperOptions(
             repo_id="mlx-community/whisper-tiny-mlx",
@@ -74,14 +76,14 @@ class TestMlxWhisperIntegration:
             logprob_threshold=-1.0,
             compression_ratio_threshold=2.4,
         )

         model = _MlxWhisperModel(
             enabled=True,
             artifacts_path=None,
             accelerator_options=accelerator_options,
             asr_options=asr_options,
         )

         assert model.enabled is True
         assert model.model_path == "mlx-community/whisper-tiny-mlx"
         assert model.language == "en"
@@ -101,8 +103,11 @@ class TestMlxWhisperIntegration:
             logprob_threshold=-1.0,
             compression_ratio_threshold=2.4,
         )

-        with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
+        with patch(
+            "builtins.__import__",
+            side_effect=ImportError("No module named 'mlx_whisper'"),
+        ):
             with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                 _MlxWhisperModel(
                     enabled=True,
@@ -111,13 +116,13 @@ class TestMlxWhisperIntegration:
                     asr_options=asr_options,
                 )

-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_transcribe(self, mock_import):
         """Test MLX Whisper transcription method."""
         # Mock the mlx_whisper module and its transcribe function
         mock_mlx_whisper = Mock()
         mock_import.return_value = mock_mlx_whisper

         # Mock the transcribe result
         mock_result = {
             "segments": [
@@ -128,12 +133,12 @@ class TestMlxWhisperIntegration:
"words": [ "words": [
{"start": 0.0, "end": 0.5, "word": "Hello"}, {"start": 0.0, "end": 0.5, "word": "Hello"},
{"start": 0.5, "end": 1.0, "word": "world"}, {"start": 0.5, "end": 1.0, "word": "world"},
] ],
} }
] ]
} }
mock_mlx_whisper.transcribe.return_value = mock_result mock_mlx_whisper.transcribe.return_value = mock_result
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS) accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions( asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx", repo_id="mlx-community/whisper-tiny-mlx",
@@ -145,18 +150,18 @@ class TestMlxWhisperIntegration:
             logprob_threshold=-1.0,
             compression_ratio_threshold=2.4,
         )

         model = _MlxWhisperModel(
             enabled=True,
             artifacts_path=None,
             accelerator_options=accelerator_options,
             asr_options=asr_options,
         )

         # Test transcription
         audio_path = Path("test_audio.wav")
         result = model.transcribe(audio_path)

         # Verify the result
         assert len(result) == 1
         assert result[0].start_time == 0.0
@@ -165,7 +170,7 @@ class TestMlxWhisperIntegration:
         assert len(result[0].words) == 2
         assert result[0].words[0].text == "Hello"
         assert result[0].words[1].text == "world"

         # Verify mlx_whisper.transcribe was called with correct parameters
         mock_mlx_whisper.transcribe.assert_called_once_with(
             str(audio_path),
@@ -178,13 +183,13 @@ class TestMlxWhisperIntegration:
             compression_ratio_threshold=2.4,
         )

-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_asr_pipeline_with_mlx_whisper(self, mock_import):
         """Test that AsrPipeline can be initialized with MLX Whisper options."""
         # Mock the mlx_whisper import
         mock_mlx_whisper = Mock()
         mock_import.return_value = mock_mlx_whisper

         accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
         asr_options = InlineAsrMlxWhisperOptions(
             repo_id="mlx-community/whisper-tiny-mlx",
@@ -200,7 +205,7 @@ class TestMlxWhisperIntegration:
             asr_options=asr_options,
             accelerator_options=accelerator_options,
         )

         pipeline = AsrPipeline(pipeline_options)

         assert isinstance(pipeline._model, _MlxWhisperModel)
         assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
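
These tests stub imports at the lowest level rather than patching `mlx_whisper` by module path. A self-contained sketch of that `builtins.__import__` patching technique, independent of docling:

from unittest.mock import Mock, patch


def load_backend():
    import mlx_whisper  # resolved through builtins.__import__ at call time

    return mlx_whisper


with patch("builtins.__import__") as mock_import:
    # Every import statement executed inside this block now returns the mock,
    # so the test never needs mlx-whisper (or Apple Silicon) installed.
    mock_import.return_value = Mock(name="mlx_whisper")
    backend = load_backend()
    assert backend is mock_import.return_value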