mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-09 13:18:24 +00:00
fix pre-commit checks and add proper type safety
This commit is contained in:
@@ -1,25 +1,27 @@
|
||||
"""
|
||||
Test MLX Whisper integration for Apple Silicon ASR pipeline.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.asr_model_specs import (
|
||||
WHISPER_TINY,
|
||||
WHISPER_BASE,
|
||||
WHISPER_SMALL,
|
||||
WHISPER_MEDIUM,
|
||||
WHISPER_LARGE,
|
||||
WHISPER_MEDIUM,
|
||||
WHISPER_SMALL,
|
||||
WHISPER_TINY,
|
||||
WHISPER_TURBO,
|
||||
)
|
||||
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
||||
from docling.datamodel.pipeline_options_asr_model import (
|
||||
InferenceAsrFramework,
|
||||
InlineAsrMlxWhisperOptions,
|
||||
)
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
|
||||
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
|
||||
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
||||
|
||||
|
||||
class TestMlxWhisperIntegration:
|
||||
@@ -32,7 +34,7 @@ class TestMlxWhisperIntegration:
|
||||
language="en",
|
||||
task="transcribe",
|
||||
)
|
||||
|
||||
|
||||
assert options.inference_framework == InferenceAsrFramework.MLX
|
||||
assert options.repo_id == "mlx-community/whisper-tiny-mlx"
|
||||
assert options.language == "en"
|
||||
@@ -45,24 +47,24 @@ class TestMlxWhisperIntegration:
|
||||
# This test verifies that the models are correctly configured
|
||||
# In a real Apple Silicon environment with mlx-whisper installed,
|
||||
# these models would automatically use MLX
|
||||
|
||||
# Check that the models exist and have the correct structure
|
||||
assert hasattr(WHISPER_TURBO, 'inference_framework')
|
||||
assert hasattr(WHISPER_TURBO, 'repo_id')
|
||||
|
||||
assert hasattr(WHISPER_BASE, 'inference_framework')
|
||||
assert hasattr(WHISPER_BASE, 'repo_id')
|
||||
|
||||
assert hasattr(WHISPER_SMALL, 'inference_framework')
|
||||
assert hasattr(WHISPER_SMALL, 'repo_id')
|
||||
|
||||
@patch('builtins.__import__')
|
||||
# Check that the models exist and have the correct structure
|
||||
assert hasattr(WHISPER_TURBO, "inference_framework")
|
||||
assert hasattr(WHISPER_TURBO, "repo_id")
|
||||
|
||||
assert hasattr(WHISPER_BASE, "inference_framework")
|
||||
assert hasattr(WHISPER_BASE, "repo_id")
|
||||
|
||||
assert hasattr(WHISPER_SMALL, "inference_framework")
|
||||
assert hasattr(WHISPER_SMALL, "repo_id")
|
||||
|
||||
@patch("builtins.__import__")
|
||||
def test_mlx_whisper_model_initialization(self, mock_import):
|
||||
"""Test MLX Whisper model initialization."""
|
||||
# Mock the mlx_whisper import
|
||||
mock_mlx_whisper = Mock()
|
||||
mock_import.return_value = mock_mlx_whisper
|
||||
|
||||
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
@@ -74,14 +76,14 @@ class TestMlxWhisperIntegration:
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
|
||||
model = _MlxWhisperModel(
|
||||
enabled=True,
|
||||
artifacts_path=None,
|
||||
accelerator_options=accelerator_options,
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
|
||||
assert model.enabled is True
|
||||
assert model.model_path == "mlx-community/whisper-tiny-mlx"
|
||||
assert model.language == "en"
|
||||
@@ -101,8 +103,11 @@ class TestMlxWhisperIntegration:
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
|
||||
|
||||
with patch(
|
||||
"builtins.__import__",
|
||||
side_effect=ImportError("No module named 'mlx_whisper'"),
|
||||
):
|
||||
with pytest.raises(ImportError, match="mlx-whisper is not installed"):
|
||||
_MlxWhisperModel(
|
||||
enabled=True,
|
||||
@@ -111,13 +116,13 @@ class TestMlxWhisperIntegration:
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
@patch('builtins.__import__')
|
||||
@patch("builtins.__import__")
|
||||
def test_mlx_whisper_transcribe(self, mock_import):
|
||||
"""Test MLX Whisper transcription method."""
|
||||
# Mock the mlx_whisper module and its transcribe function
|
||||
mock_mlx_whisper = Mock()
|
||||
mock_import.return_value = mock_mlx_whisper
|
||||
|
||||
|
||||
# Mock the transcribe result
|
||||
mock_result = {
|
||||
"segments": [
|
||||
@@ -128,12 +133,12 @@ class TestMlxWhisperIntegration:
|
||||
"words": [
|
||||
{"start": 0.0, "end": 0.5, "word": "Hello"},
|
||||
{"start": 0.5, "end": 1.0, "word": "world"},
|
||||
]
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_mlx_whisper.transcribe.return_value = mock_result
|
||||
|
||||
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
@@ -145,18 +150,18 @@ class TestMlxWhisperIntegration:
|
||||
logprob_threshold=-1.0,
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
|
||||
model = _MlxWhisperModel(
|
||||
enabled=True,
|
||||
artifacts_path=None,
|
||||
accelerator_options=accelerator_options,
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
|
||||
# Test transcription
|
||||
audio_path = Path("test_audio.wav")
|
||||
result = model.transcribe(audio_path)
|
||||
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert result[0].start_time == 0.0
|
||||
@@ -165,7 +170,7 @@ class TestMlxWhisperIntegration:
|
||||
assert len(result[0].words) == 2
|
||||
assert result[0].words[0].text == "Hello"
|
||||
assert result[0].words[1].text == "world"
|
||||
|
||||
|
||||
# Verify mlx_whisper.transcribe was called with correct parameters
|
||||
mock_mlx_whisper.transcribe.assert_called_once_with(
|
||||
str(audio_path),
|
||||
@@ -178,13 +183,13 @@ class TestMlxWhisperIntegration:
|
||||
compression_ratio_threshold=2.4,
|
||||
)
|
||||
|
||||
@patch('builtins.__import__')
|
||||
@patch("builtins.__import__")
|
||||
def test_asr_pipeline_with_mlx_whisper(self, mock_import):
|
||||
"""Test that AsrPipeline can be initialized with MLX Whisper options."""
|
||||
# Mock the mlx_whisper import
|
||||
mock_mlx_whisper = Mock()
|
||||
mock_import.return_value = mock_mlx_whisper
|
||||
|
||||
|
||||
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||
asr_options = InlineAsrMlxWhisperOptions(
|
||||
repo_id="mlx-community/whisper-tiny-mlx",
|
||||
@@ -200,7 +205,7 @@ class TestMlxWhisperIntegration:
|
||||
asr_options=asr_options,
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
|
||||
|
||||
pipeline = AsrPipeline(pipeline_options)
|
||||
assert isinstance(pipeline._model, _MlxWhisperModel)
|
||||
assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
|
||||
|
||||
Reference in New Issue
Block a user