fix pre-commit checks and added proper type safety

This commit is contained in:
Ken Steele
2025-10-02 04:53:49 -07:00
parent 94803317a3
commit 21905e8ace
7 changed files with 135 additions and 94 deletions

View File

@@ -1,25 +1,27 @@
"""
Test MLX Whisper integration for Apple Silicon ASR pipeline.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
import pytest
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import (
WHISPER_TINY,
WHISPER_BASE,
WHISPER_SMALL,
WHISPER_MEDIUM,
WHISPER_LARGE,
WHISPER_MEDIUM,
WHISPER_SMALL,
WHISPER_TINY,
WHISPER_TURBO,
)
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import (
InferenceAsrFramework,
InlineAsrMlxWhisperOptions,
)
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
from docling.datamodel.pipeline_options import AsrPipelineOptions
class TestMlxWhisperIntegration:
@@ -32,7 +34,7 @@ class TestMlxWhisperIntegration:
language="en",
task="transcribe",
)
assert options.inference_framework == InferenceAsrFramework.MLX
assert options.repo_id == "mlx-community/whisper-tiny-mlx"
assert options.language == "en"
@@ -45,24 +47,24 @@ class TestMlxWhisperIntegration:
# This test verifies that the models are correctly configured
# In a real Apple Silicon environment with mlx-whisper installed,
# these models would automatically use MLX
# Check that the models exist and have the correct structure
assert hasattr(WHISPER_TURBO, 'inference_framework')
assert hasattr(WHISPER_TURBO, 'repo_id')
assert hasattr(WHISPER_BASE, 'inference_framework')
assert hasattr(WHISPER_BASE, 'repo_id')
assert hasattr(WHISPER_SMALL, 'inference_framework')
assert hasattr(WHISPER_SMALL, 'repo_id')
@patch('builtins.__import__')
# Check that the models exist and have the correct structure
assert hasattr(WHISPER_TURBO, "inference_framework")
assert hasattr(WHISPER_TURBO, "repo_id")
assert hasattr(WHISPER_BASE, "inference_framework")
assert hasattr(WHISPER_BASE, "repo_id")
assert hasattr(WHISPER_SMALL, "inference_framework")
assert hasattr(WHISPER_SMALL, "repo_id")
@patch("builtins.__import__")
def test_mlx_whisper_model_initialization(self, mock_import):
"""Test MLX Whisper model initialization."""
# Mock the mlx_whisper import
mock_mlx_whisper = Mock()
mock_import.return_value = mock_mlx_whisper
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
@@ -74,14 +76,14 @@ class TestMlxWhisperIntegration:
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
model = _MlxWhisperModel(
enabled=True,
artifacts_path=None,
accelerator_options=accelerator_options,
asr_options=asr_options,
)
assert model.enabled is True
assert model.model_path == "mlx-community/whisper-tiny-mlx"
assert model.language == "en"
@@ -101,8 +103,11 @@ class TestMlxWhisperIntegration:
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
with patch(
"builtins.__import__",
side_effect=ImportError("No module named 'mlx_whisper'"),
):
with pytest.raises(ImportError, match="mlx-whisper is not installed"):
_MlxWhisperModel(
enabled=True,
@@ -111,13 +116,13 @@ class TestMlxWhisperIntegration:
asr_options=asr_options,
)
@patch('builtins.__import__')
@patch("builtins.__import__")
def test_mlx_whisper_transcribe(self, mock_import):
"""Test MLX Whisper transcription method."""
# Mock the mlx_whisper module and its transcribe function
mock_mlx_whisper = Mock()
mock_import.return_value = mock_mlx_whisper
# Mock the transcribe result
mock_result = {
"segments": [
@@ -128,12 +133,12 @@ class TestMlxWhisperIntegration:
"words": [
{"start": 0.0, "end": 0.5, "word": "Hello"},
{"start": 0.5, "end": 1.0, "word": "world"},
]
],
}
]
}
mock_mlx_whisper.transcribe.return_value = mock_result
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
@@ -145,18 +150,18 @@ class TestMlxWhisperIntegration:
logprob_threshold=-1.0,
compression_ratio_threshold=2.4,
)
model = _MlxWhisperModel(
enabled=True,
artifacts_path=None,
accelerator_options=accelerator_options,
asr_options=asr_options,
)
# Test transcription
audio_path = Path("test_audio.wav")
result = model.transcribe(audio_path)
# Verify the result
assert len(result) == 1
assert result[0].start_time == 0.0
@@ -165,7 +170,7 @@ class TestMlxWhisperIntegration:
assert len(result[0].words) == 2
assert result[0].words[0].text == "Hello"
assert result[0].words[1].text == "world"
# Verify mlx_whisper.transcribe was called with correct parameters
mock_mlx_whisper.transcribe.assert_called_once_with(
str(audio_path),
@@ -178,13 +183,13 @@ class TestMlxWhisperIntegration:
compression_ratio_threshold=2.4,
)
@patch('builtins.__import__')
@patch("builtins.__import__")
def test_asr_pipeline_with_mlx_whisper(self, mock_import):
"""Test that AsrPipeline can be initialized with MLX Whisper options."""
# Mock the mlx_whisper import
mock_mlx_whisper = Mock()
mock_import.return_value = mock_mlx_whisper
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
asr_options = InlineAsrMlxWhisperOptions(
repo_id="mlx-community/whisper-tiny-mlx",
@@ -200,7 +205,7 @@ class TestMlxWhisperIntegration:
asr_options=asr_options,
accelerator_options=accelerator_options,
)
pipeline = AsrPipeline(pipeline_options)
assert isinstance(pipeline._model, _MlxWhisperModel)
assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"