mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-09 13:18:24 +00:00
added mlx-whisper example and test. update docling cli to use MLX automatically if present.
This commit is contained in:
206
tests/test_asr_mlx_whisper.py
Normal file
206
tests/test_asr_mlx_whisper.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Test MLX Whisper integration for Apple Silicon ASR pipeline.
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from docling.datamodel.asr_model_specs import (
|
||||
WHISPER_TINY,
|
||||
WHISPER_BASE,
|
||||
WHISPER_SMALL,
|
||||
WHISPER_MEDIUM,
|
||||
WHISPER_LARGE,
|
||||
WHISPER_TURBO,
|
||||
)
|
||||
from docling.datamodel.pipeline_options_asr_model import (
|
||||
InferenceAsrFramework,
|
||||
InlineAsrMlxWhisperOptions,
|
||||
)
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
|
||||
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
|
||||
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
||||
|
||||
|
||||
class TestMlxWhisperIntegration:
|
||||
"""Test MLX Whisper model integration."""
|
||||
|
||||
def test_mlx_whisper_options_creation(self):
|
||||
"""Test that MLX Whisper options are created correctly."""
|
||||
options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
language="en",
|
||||
task="transcribe",
|
||||
)
|
||||
|
||||
assert options.inference_framework == InferenceAsrFramework.MLX
|
||||
assert options.repo_id == "mlx-community/whisper-tiny-mlx"
|
||||
assert options.language == "en"
|
||||
assert options.task == "transcribe"
|
||||
assert options.word_timestamps is True
|
||||
assert AcceleratorDevice.MPS in options.supported_devices
|
||||
|
||||
def test_whisper_models_auto_select_mlx(self):
|
||||
"""Test that Whisper models automatically select MLX when MPS and mlx-whisper are available."""
|
||||
# This test verifies that the models are correctly configured
|
||||
# In a real Apple Silicon environment with mlx-whisper installed,
|
||||
# these models would automatically use MLX
|
||||
|
||||
# Check that the models exist and have the correct structure
|
||||
assert hasattr(WHISPER_TURBO, 'inference_framework')
|
||||
assert hasattr(WHISPER_TURBO, 'repo_id')
|
||||
|
||||
assert hasattr(WHISPER_BASE, 'inference_framework')
|
||||
assert hasattr(WHISPER_BASE, 'repo_id')
|
||||
|
||||
assert hasattr(WHISPER_SMALL, 'inference_framework')
|
||||
assert hasattr(WHISPER_SMALL, 'repo_id')
|
||||
|
||||
@patch('builtins.__import__')
|
||||
def test_mlx_whisper_model_initialization(self, mock_import):
|
||||
"""Test MLX Whisper model initialization."""
|
||||
# Mock the mlx_whisper import
|
||||
mock_mlx_whisper = Mock()
|
||||
mock_import.return_value = mock_mlx_whisper
|
||||
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
inference_framework=InferenceAsrFramework.MLX,
|
||||
language="en",
|
||||
task="transcribe",
|
||||
word_timestamps=True,
|
||||
no_speech_threshold=0.6,
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
model = _MlxWhisperModel(
|
||||
enabled=True,
|
||||
artifacts_path=None,
|
||||
accelerator_options=accelerator_options,
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
assert model.enabled is True
|
||||
assert model.model_path == "mlx-community/whisper-tiny-mlx"
|
||||
assert model.language == "en"
|
||||
assert model.task == "transcribe"
|
||||
assert model.word_timestamps is True
|
||||
|
||||
def test_mlx_whisper_model_import_error(self):
|
||||
"""Test that ImportError is raised when mlx-whisper is not available."""
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
inference_framework=InferenceAsrFramework.MLX,
|
||||
language="en",
|
||||
task="transcribe",
|
||||
word_timestamps=True,
|
||||
no_speech_threshold=0.6,
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
|
||||
with pytest.raises(ImportError, match="mlx-whisper is not installed"):
|
||||
_MlxWhisperModel(
|
||||
enabled=True,
|
||||
artifacts_path=None,
|
||||
accelerator_options=accelerator_options,
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
@patch('builtins.__import__')
|
||||
def test_mlx_whisper_transcribe(self, mock_import):
|
||||
"""Test MLX Whisper transcription method."""
|
||||
# Mock the mlx_whisper module and its transcribe function
|
||||
mock_mlx_whisper = Mock()
|
||||
mock_import.return_value = mock_mlx_whisper
|
||||
|
||||
# Mock the transcribe result
|
||||
mock_result = {
|
||||
"segments": [
|
||||
{
|
||||
"start": 0.0,
|
||||
"end": 2.5,
|
||||
"text": "Hello world",
|
||||
"words": [
|
||||
{"start": 0.0, "end": 0.5, "word": "Hello"},
|
||||
{"start": 0.5, "end": 1.0, "word": "world"},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_mlx_whisper.transcribe.return_value = mock_result
|
||||
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
inference_framework=InferenceAsrFramework.MLX,
|
||||
language="en",
|
||||
task="transcribe",
|
||||
word_timestamps=True,
|
||||
no_speech_threshold=0.6,
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
model = _MlxWhisperModel(
|
||||
enabled=True,
|
||||
artifacts_path=None,
|
||||
accelerator_options=accelerator_options,
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
# Test transcription
|
||||
audio_path = Path("test_audio.wav")
|
||||
result = model.transcribe(audio_path)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert result[0].start_time == 0.0
|
||||
assert result[0].end_time == 2.5
|
||||
assert result[0].text == "Hello world"
|
||||
assert len(result[0].words) == 2
|
||||
assert result[0].words[0].text == "Hello"
|
||||
assert result[0].words[1].text == "world"
|
||||
|
||||
# Verify mlx_whisper.transcribe was called with correct parameters
|
||||
mock_mlx_whisper.transcribe.assert_called_once_with(
|
||||
str(audio_path),
|
||||
path_or_hf_repo="mlx-community/whisper-tiny-mlx",
|
||||
language="en",
|
||||
task="transcribe",
|
||||
word_timestamps=True,
|
||||
no_speech_threshold=0.6,
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
@patch('builtins.__import__')
|
||||
def test_asr_pipeline_with_mlx_whisper(self, mock_import):
|
||||
"""Test that AsrPipeline can be initialized with MLX Whisper options."""
|
||||
# Mock the mlx_whisper import
|
||||
mock_mlx_whisper = Mock()
|
||||
mock_import.return_value = mock_mlx_whisper
|
||||
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
inference_framework=InferenceAsrFramework.MLX,
|
||||
language="en",
|
||||
task="transcribe",
|
||||
word_timestamps=True,
|
||||
no_speech_threshold=0.6,
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
pipeline_options = AsrPipelineOptions(
|
||||
asr_options=asr_options,
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
|
||||
pipeline = AsrPipeline(pipeline_options)
|
||||
assert isinstance(pipeline._model, _MlxWhisperModel)
|
||||
assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
|
||||
Reference in New Issue
Block a user