mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
207 lines
7.7 KiB
Python
207 lines
7.7 KiB
Python
"""
|
|
Test MLX Whisper integration for Apple Silicon ASR pipeline.
|
|
"""
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
from docling.datamodel.asr_model_specs import (
|
|
WHISPER_TINY,
|
|
WHISPER_BASE,
|
|
WHISPER_SMALL,
|
|
WHISPER_MEDIUM,
|
|
WHISPER_LARGE,
|
|
WHISPER_TURBO,
|
|
)
|
|
from docling.datamodel.pipeline_options_asr_model import (
|
|
InferenceAsrFramework,
|
|
InlineAsrMlxWhisperOptions,
|
|
)
|
|
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
|
|
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
|
|
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
|
|
|
|
|
class TestMlxWhisperIntegration:
    """Test MLX Whisper model integration.

    The tests never import the real ``mlx_whisper`` package. Wherever a
    ``_MlxWhisperModel`` is constructed, the module is substituted (or
    blocked) through a scoped ``patch.dict("sys.modules", ...)`` so that only
    the ``mlx_whisper`` import is affected and the original ``sys.modules``
    content is restored when the ``with`` block exits.
    """

    # Repository id used by every test that builds options or a model.
    _REPO_ID = "mlx-community/whisper-tiny-mlx"

    def _make_asr_options(self):
        """Return fully specified ``InlineAsrMlxWhisperOptions`` for the model tests."""
        return InlineAsrMlxWhisperOptions(
            repo_id=self._REPO_ID,
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

    def _make_model(self):
        """Construct an enabled ``_MlxWhisperModel`` targeting the MPS device.

        Callers are expected to patch ``sys.modules["mlx_whisper"]`` first;
        the import-error test relies on this constructor raising otherwise.
        """
        return _MlxWhisperModel(
            enabled=True,
            artifacts_path=None,
            accelerator_options=AcceleratorOptions(device=AcceleratorDevice.MPS),
            asr_options=self._make_asr_options(),
        )

    def test_mlx_whisper_options_creation(self):
        """Test that MLX Whisper options are created correctly."""
        options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            language="en",
            task="transcribe",
        )

        assert options.inference_framework == InferenceAsrFramework.MLX
        assert options.repo_id == "mlx-community/whisper-tiny-mlx"
        assert options.language == "en"
        assert options.task == "transcribe"
        # word_timestamps was not passed above, so this checks its default.
        assert options.word_timestamps is True
        assert AcceleratorDevice.MPS in options.supported_devices

    def test_whisper_models_auto_select_mlx(self):
        """Test that the predefined Whisper specs expose the expected fields.

        This only verifies the structure of the model specs. In a real Apple
        Silicon environment with mlx-whisper installed, these models would
        automatically use MLX; that selection itself is not exercised here.
        """
        model_specs = {
            "WHISPER_TINY": WHISPER_TINY,
            "WHISPER_BASE": WHISPER_BASE,
            "WHISPER_SMALL": WHISPER_SMALL,
            "WHISPER_MEDIUM": WHISPER_MEDIUM,
            "WHISPER_LARGE": WHISPER_LARGE,
            "WHISPER_TURBO": WHISPER_TURBO,
        }

        for spec_name, spec in model_specs.items():
            assert hasattr(spec, "inference_framework"), (
                f"{spec_name} has no inference_framework"
            )
            assert hasattr(spec, "repo_id"), f"{spec_name} has no repo_id"

    def test_mlx_whisper_model_initialization(self):
        """Test MLX Whisper model initialization."""
        # Substitute only the mlx_whisper module; other imports stay real.
        with patch.dict("sys.modules", {"mlx_whisper": Mock()}):
            model = self._make_model()

        assert model.enabled is True
        assert model.model_path == "mlx-community/whisper-tiny-mlx"
        assert model.language == "en"
        assert model.task == "transcribe"
        assert model.word_timestamps is True

    def test_mlx_whisper_model_import_error(self):
        """Test that ImportError is raised when mlx-whisper is not available."""
        # A None entry in sys.modules makes `import mlx_whisper` raise
        # ModuleNotFoundError (an ImportError subclass), even on machines
        # where the package is actually installed.
        with patch.dict("sys.modules", {"mlx_whisper": None}):
            with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                self._make_model()

    def test_mlx_whisper_transcribe(self):
        """Test MLX Whisper transcription method."""
        # Stand-in for the mlx_whisper module and its transcribe function.
        mock_mlx_whisper = Mock()
        mock_mlx_whisper.transcribe.return_value = {
            "segments": [
                {
                    "start": 0.0,
                    "end": 2.5,
                    "text": "Hello world",
                    "words": [
                        {"start": 0.0, "end": 0.5, "word": "Hello"},
                        {"start": 0.5, "end": 1.0, "word": "world"},
                    ],
                }
            ]
        }

        audio_path = Path("test_audio.wav")

        # Keep the substitution active for both construction and the call.
        with patch.dict("sys.modules", {"mlx_whisper": mock_mlx_whisper}):
            model = self._make_model()
            result = model.transcribe(audio_path)

        # Verify the result
        assert len(result) == 1
        segment = result[0]
        assert segment.start_time == 0.0
        assert segment.end_time == 2.5
        assert segment.text == "Hello world"
        assert len(segment.words) == 2
        assert segment.words[0].text == "Hello"
        assert segment.words[1].text == "world"

        # Verify mlx_whisper.transcribe was called with correct parameters
        mock_mlx_whisper.transcribe.assert_called_once_with(
            str(audio_path),
            path_or_hf_repo="mlx-community/whisper-tiny-mlx",
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

    def test_asr_pipeline_with_mlx_whisper(self):
        """Test that AsrPipeline can be initialized with MLX Whisper options."""
        pipeline_options = AsrPipelineOptions(
            asr_options=self._make_asr_options(),
            accelerator_options=AcceleratorOptions(device=AcceleratorDevice.MPS),
        )

        # Substitute only the mlx_whisper module while the pipeline is built.
        with patch.dict("sys.modules", {"mlx_whisper": Mock()}):
            pipeline = AsrPipeline(pipeline_options)

        assert isinstance(pipeline._model, _MlxWhisperModel)
        assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"