mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
* add mlx-whisper support
* add mlx-whisper example and test; update the docling CLI to use MLX automatically if present
* fix pre-commit checks and add proper type safety
* fix linter issue
* DCO Remediation Commit for Ken Steele <ksteele@gmail.com>: add Signed-off-by to commits
  a979a680e1dc2fee8461401335cfb5dda8cfdd98, 9827068382ca946fe1387ed83f747ae509fcf229,
  ebbeb45c7dc266260e1fad6bdb54a7041f8aeed4, 2f6fd3cf46c8ca0bb98810191578278f1df87aa3
* fix unit tests and code coverage for CI
* DCO Remediation Commit for Ken Steele <ksteele@gmail.com>: add Signed-off-by to commit
  5e61bf11139a2133978db2c8d306be6289aed732
* fix CI example test: mlx_whisper_example.py defaults to tests/data/audio/sample_10s.mp3 if no args are specified
* refactor: centralize audio file extensions and MIME types in base_models.py
  - Move audio file extensions from the CLI's hardcoded set to FormatToExtensions[InputFormat.AUDIO]
  - Add support for additional audio formats: m4a, aac, ogg, flac, mp4, avi, mov
  - Update the FormatToMimeType mapping to include MIME types for all audio formats
  - Update CLI auto-detection to use the centralized FormatToExtensions mapping
  - Add comprehensive tests for audio file auto-detection and pipeline selection
  - Ensure explicit pipeline choices are not overridden by auto-detection
  Fixes an issue where only .mp3 and .wav files were processed as audio despite CLI auto-detection working for all formats. The document converter now properly recognizes all audio formats through MIME type detection.
  Addresses review comments: centralizes audio extensions in base_models.py as suggested, maintains the existing auto-detection behavior while using the centralized data, and adds proper test coverage for the audio detection functionality. All examples and tests pass with the new centralized approach; all audio formats (mp3, wav, m4a, aac, ogg, flac, mp4, avi, mov) now work correctly.
* feat: address reviewer feedback - improve CLI auto-detection and add explicit model options
  1. Fix CLI auto-detection to switch to the ASR pipeline only when ALL files are audio. Previously it switched if ANY file was audio; now all files must be audio, and a warning for mixed file types points users to --pipeline asr.
  2. Add explicit WHISPER_X_MLX and WHISPER_X_NATIVE model options so users can force a specific implementation. The auto-selecting models (WHISPER_BASE, etc.) still choose the best option for the hardware. This adds 12 new explicit model options: _MLX and _NATIVE variants for each size.
  The CLI now supports:
  - Auto-selecting: whisper_tiny, whisper_base, etc. (choose the best implementation for the hardware)
  - Explicit MLX: whisper_tiny_mlx, whisper_base_mlx, etc. (force MLX)
  - Explicit native: whisper_tiny_native, whisper_base_native, etc. (force native)
  Addresses reviewer comments from @dolfim-ibm
* DCO Remediation Commit for Ken Steele <ksteele@gmail.com>: add Signed-off-by to commits
  c60e72d2b5, 94803317a3, 21905e8ace, 96c669d155, 8371c060ea
* test(asr): add coverage for MLX options, pipeline helpers, and VLM prompts
  - tests/test_asr_mlx_whisper.py: verify explicit MLX options (framework, repo ids)
  - tests/test_asr_pipeline.py: cover _has_text/_determine_status and backend support with proper InputDocument/NoOpBackend wiring
  - tests/test_interfaces.py: add BaseVlmPageModel.formulate_prompt tests (RAW/NONE/CHAT, invalid style) with a minimal InlineVlmOptions scaffold
  Improves reliability of the ASR and VLM components by validating configuration paths and helper logic.
* test(asr): broaden coverage for model selection, pipeline flows, and VLM prompts
  - tests/test_asr_mlx_whisper.py: add MLX/native selector coverage across all Whisper sizes; validate repo_id choices under the MLX and native paths; cover the fallback path when MPS is unavailable and mlx_whisper is missing
  - tests/test_asr_pipeline.py: relax the silent-audio assertion to accept PARTIAL_SUCCESS or SUCCESS; force the CPU native path in helper tests to avoid torch in device selection; add language-handling tests for native/MLX transcribe; cover native run success (BytesIO) and failure (exception) branches; cover MLX run success/failure branches with a mocked transcribe; add init path coverage with artifacts_path
  - tests/test_interfaces.py: add focused VLM prompt tests (NONE/CHAT variants)
  Result: all tests pass with significantly improved coverage of ASR model selectors, pipeline execution paths, and VLM prompt formulation.
* simplify ASR model settings (no pipeline detection needed)
* clean up disk space in runners
--------
Signed-off-by: Ken Steele <ksteele@gmail.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
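The "switch to the ASR pipeline only when ALL files are audio" rule described above can be sketched as a small standalone check. This is an illustrative sketch, not docling's actual helper: the function name `should_use_asr_pipeline` and the local extension set are assumptions (docling keeps the real set in FormatToExtensions[InputFormat.AUDIO]).

```python
from pathlib import Path

# Illustrative copy of the audio extensions centralized in the PR
# (docling's real source of truth is FormatToExtensions[InputFormat.AUDIO]).
AUDIO_EXTENSIONS = {
    ".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac", ".mp4", ".avi", ".mov",
}


def should_use_asr_pipeline(paths: list[str]) -> bool:
    """Return True only when every input file has an audio extension.

    Mirrors the reviewed behavior: a single non-audio file keeps the
    default pipeline (the CLI would warn and suggest --pipeline asr).
    """
    suffixes = [Path(p).suffix.lower() for p in paths]
    return bool(suffixes) and all(s in AUDIO_EXTENSIONS for s in suffixes)
```

A mixed batch such as `["talk.mp3", "slides.pdf"]` therefore stays on the default pipeline, while an all-audio batch switches to ASR.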
341 lines
12 KiB
Python
"""
|
|
Test MLX Whisper integration for Apple Silicon ASR pipeline.
|
|
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
|
from docling.datamodel.asr_model_specs import (
|
|
WHISPER_BASE,
|
|
WHISPER_BASE_MLX,
|
|
WHISPER_LARGE,
|
|
WHISPER_LARGE_MLX,
|
|
WHISPER_MEDIUM,
|
|
WHISPER_SMALL,
|
|
WHISPER_TINY,
|
|
WHISPER_TURBO,
|
|
)
|
|
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
|
from docling.datamodel.pipeline_options_asr_model import (
|
|
InferenceAsrFramework,
|
|
InlineAsrMlxWhisperOptions,
|
|
)
|
|
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
|
|
|
|
|
|
class TestMlxWhisperIntegration:
|
|
"""Test MLX Whisper model integration."""
|
|
|
|
def test_mlx_whisper_options_creation(self):
|
|
"""Test that MLX Whisper options are created correctly."""
|
|
options = InlineAsrMlxWhisperOptions(
|
|
repo_id="mlx-community/whisper-tiny-mlx",
|
|
language="en",
|
|
task="transcribe",
|
|
)
|
|
|
|
assert options.inference_framework == InferenceAsrFramework.MLX
|
|
assert options.repo_id == "mlx-community/whisper-tiny-mlx"
|
|
assert options.language == "en"
|
|
assert options.task == "transcribe"
|
|
assert options.word_timestamps is True
|
|
assert AcceleratorDevice.MPS in options.supported_devices
|
|
|
|
def test_whisper_models_auto_select_mlx(self):
|
|
"""Test that Whisper models automatically select MLX when MPS and mlx-whisper are available."""
|
|
# This test verifies that the models are correctly configured
|
|
# In a real Apple Silicon environment with mlx-whisper installed,
|
|
# these models would automatically use MLX
|
|
|
|
# Check that the models exist and have the correct structure
|
|
assert hasattr(WHISPER_TURBO, "inference_framework")
|
|
assert hasattr(WHISPER_TURBO, "repo_id")
|
|
|
|
assert hasattr(WHISPER_BASE, "inference_framework")
|
|
assert hasattr(WHISPER_BASE, "repo_id")
|
|
|
|
assert hasattr(WHISPER_SMALL, "inference_framework")
|
|
assert hasattr(WHISPER_SMALL, "repo_id")
|
|
|
|
def test_explicit_mlx_models_shape(self):
|
|
"""Explicit MLX options should have MLX framework and valid repos."""
|
|
assert WHISPER_BASE_MLX.inference_framework.name == "MLX"
|
|
assert WHISPER_LARGE_MLX.inference_framework.name == "MLX"
|
|
assert WHISPER_BASE_MLX.repo_id.startswith("mlx-community/")
|
|
|
|
def test_model_selectors_mlx_and_native_paths(self, monkeypatch):
|
|
"""Cover MLX/native selection branches in asr_model_specs getters."""
|
|
from docling.datamodel import asr_model_specs as specs
|
|
|
|
# Force MLX path
|
|
class _Mps:
|
|
def is_built(self):
|
|
return True
|
|
|
|
def is_available(self):
|
|
return True
|
|
|
|
class _Torch:
|
|
class backends:
|
|
mps = _Mps()
|
|
|
|
monkeypatch.setitem(sys.modules, "torch", _Torch())
|
|
monkeypatch.setitem(sys.modules, "mlx_whisper", object())
|
|
|
|
m_tiny = specs._get_whisper_tiny_model()
|
|
m_small = specs._get_whisper_small_model()
|
|
m_base = specs._get_whisper_base_model()
|
|
m_medium = specs._get_whisper_medium_model()
|
|
m_large = specs._get_whisper_large_model()
|
|
m_turbo = specs._get_whisper_turbo_model()
|
|
assert (
|
|
m_tiny.inference_framework == InferenceAsrFramework.MLX
|
|
and m_tiny.repo_id.startswith("mlx-community/whisper-tiny")
|
|
)
|
|
assert (
|
|
m_small.inference_framework == InferenceAsrFramework.MLX
|
|
and m_small.repo_id.startswith("mlx-community/whisper-small")
|
|
)
|
|
assert (
|
|
m_base.inference_framework == InferenceAsrFramework.MLX
|
|
and m_base.repo_id.startswith("mlx-community/whisper-base")
|
|
)
|
|
assert (
|
|
m_medium.inference_framework == InferenceAsrFramework.MLX
|
|
and "medium" in m_medium.repo_id
|
|
)
|
|
assert (
|
|
m_large.inference_framework == InferenceAsrFramework.MLX
|
|
and "large" in m_large.repo_id
|
|
)
|
|
assert (
|
|
m_turbo.inference_framework == InferenceAsrFramework.MLX
|
|
and m_turbo.repo_id.endswith("whisper-turbo")
|
|
)
|
|
|
|
# Force native path (no mlx or no mps)
|
|
if "mlx_whisper" in sys.modules:
|
|
del sys.modules["mlx_whisper"]
|
|
|
|
class _MpsOff:
|
|
def is_built(self):
|
|
return False
|
|
|
|
def is_available(self):
|
|
return False
|
|
|
|
class _TorchOff:
|
|
class backends:
|
|
mps = _MpsOff()
|
|
|
|
monkeypatch.setitem(sys.modules, "torch", _TorchOff())
|
|
n_tiny = specs._get_whisper_tiny_model()
|
|
n_small = specs._get_whisper_small_model()
|
|
n_base = specs._get_whisper_base_model()
|
|
n_medium = specs._get_whisper_medium_model()
|
|
n_large = specs._get_whisper_large_model()
|
|
n_turbo = specs._get_whisper_turbo_model()
|
|
assert (
|
|
n_tiny.inference_framework == InferenceAsrFramework.WHISPER
|
|
and n_tiny.repo_id == "tiny"
|
|
)
|
|
assert (
|
|
n_small.inference_framework == InferenceAsrFramework.WHISPER
|
|
and n_small.repo_id == "small"
|
|
)
|
|
assert (
|
|
n_base.inference_framework == InferenceAsrFramework.WHISPER
|
|
and n_base.repo_id == "base"
|
|
)
|
|
assert (
|
|
n_medium.inference_framework == InferenceAsrFramework.WHISPER
|
|
and n_medium.repo_id == "medium"
|
|
)
|
|
assert (
|
|
n_large.inference_framework == InferenceAsrFramework.WHISPER
|
|
and n_large.repo_id == "large"
|
|
)
|
|
assert (
|
|
n_turbo.inference_framework == InferenceAsrFramework.WHISPER
|
|
and n_turbo.repo_id == "turbo"
|
|
)
|
|
|
|
def test_selector_import_errors_force_native(self, monkeypatch):
|
|
"""If torch import fails, selector must return native."""
|
|
from docling.datamodel import asr_model_specs as specs
|
|
|
|
# Simulate environment where MPS is unavailable and mlx_whisper missing
|
|
class _MpsOff:
|
|
def is_built(self):
|
|
return False
|
|
|
|
def is_available(self):
|
|
return False
|
|
|
|
class _TorchOff:
|
|
class backends:
|
|
mps = _MpsOff()
|
|
|
|
monkeypatch.setitem(sys.modules, "torch", _TorchOff())
|
|
if "mlx_whisper" in sys.modules:
|
|
del sys.modules["mlx_whisper"]
|
|
|
|
model = specs._get_whisper_base_model()
|
|
assert model.inference_framework == InferenceAsrFramework.WHISPER
|
|
|
|
@patch("builtins.__import__")
|
|
def test_mlx_whisper_model_initialization(self, mock_import):
|
|
"""Test MLX Whisper model initialization."""
|
|
# Mock the mlx_whisper import
|
|
mock_mlx_whisper = Mock()
|
|
mock_import.return_value = mock_mlx_whisper
|
|
|
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
|
asr_options = InlineAsrMlxWhisperOptions(
|
|
repo_id="mlx-community/whisper-tiny-mlx",
|
|
inference_framework=InferenceAsrFramework.MLX,
|
|
language="en",
|
|
task="transcribe",
|
|
word_timestamps=True,
|
|
no_speech_threshold=0.6,
|
|
logprob_threshold=-1.0,
|
|
compression_ratio_threshold=2.4,
|
|
)
|
|
|
|
model = _MlxWhisperModel(
|
|
enabled=True,
|
|
artifacts_path=None,
|
|
accelerator_options=accelerator_options,
|
|
asr_options=asr_options,
|
|
)
|
|
|
|
assert model.enabled is True
|
|
assert model.model_path == "mlx-community/whisper-tiny-mlx"
|
|
assert model.language == "en"
|
|
assert model.task == "transcribe"
|
|
assert model.word_timestamps is True
|
|
|
|
def test_mlx_whisper_model_import_error(self):
|
|
"""Test that ImportError is raised when mlx-whisper is not available."""
|
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
|
asr_options = InlineAsrMlxWhisperOptions(
|
|
repo_id="mlx-community/whisper-tiny-mlx",
|
|
inference_framework=InferenceAsrFramework.MLX,
|
|
language="en",
|
|
task="transcribe",
|
|
word_timestamps=True,
|
|
no_speech_threshold=0.6,
|
|
logprob_threshold=-1.0,
|
|
compression_ratio_threshold=2.4,
|
|
)
|
|
|
|
with patch(
|
|
"builtins.__import__",
|
|
side_effect=ImportError("No module named 'mlx_whisper'"),
|
|
):
|
|
with pytest.raises(ImportError, match="mlx-whisper is not installed"):
|
|
_MlxWhisperModel(
|
|
enabled=True,
|
|
artifacts_path=None,
|
|
accelerator_options=accelerator_options,
|
|
asr_options=asr_options,
|
|
)
|
|
|
|
@patch("builtins.__import__")
|
|
def test_mlx_whisper_transcribe(self, mock_import):
|
|
"""Test MLX Whisper transcription method."""
|
|
# Mock the mlx_whisper module and its transcribe function
|
|
mock_mlx_whisper = Mock()
|
|
mock_import.return_value = mock_mlx_whisper
|
|
|
|
# Mock the transcribe result
|
|
mock_result = {
|
|
"segments": [
|
|
{
|
|
"start": 0.0,
|
|
"end": 2.5,
|
|
"text": "Hello world",
|
|
"words": [
|
|
{"start": 0.0, "end": 0.5, "word": "Hello"},
|
|
{"start": 0.5, "end": 1.0, "word": "world"},
|
|
],
|
|
}
|
|
]
|
|
}
|
|
mock_mlx_whisper.transcribe.return_value = mock_result
|
|
|
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
|
asr_options = InlineAsrMlxWhisperOptions(
|
|
repo_id="mlx-community/whisper-tiny-mlx",
|
|
inference_framework=InferenceAsrFramework.MLX,
|
|
language="en",
|
|
task="transcribe",
|
|
word_timestamps=True,
|
|
no_speech_threshold=0.6,
|
|
logprob_threshold=-1.0,
|
|
compression_ratio_threshold=2.4,
|
|
)
|
|
|
|
model = _MlxWhisperModel(
|
|
enabled=True,
|
|
artifacts_path=None,
|
|
accelerator_options=accelerator_options,
|
|
asr_options=asr_options,
|
|
)
|
|
|
|
# Test transcription
|
|
audio_path = Path("test_audio.wav")
|
|
result = model.transcribe(audio_path)
|
|
|
|
# Verify the result
|
|
assert len(result) == 1
|
|
assert result[0].start_time == 0.0
|
|
assert result[0].end_time == 2.5
|
|
assert result[0].text == "Hello world"
|
|
assert len(result[0].words) == 2
|
|
assert result[0].words[0].text == "Hello"
|
|
assert result[0].words[1].text == "world"
|
|
|
|
# Verify mlx_whisper.transcribe was called with correct parameters
|
|
mock_mlx_whisper.transcribe.assert_called_once_with(
|
|
str(audio_path),
|
|
path_or_hf_repo="mlx-community/whisper-tiny-mlx",
|
|
language="en",
|
|
task="transcribe",
|
|
word_timestamps=True,
|
|
no_speech_threshold=0.6,
|
|
logprob_threshold=-1.0,
|
|
compression_ratio_threshold=2.4,
|
|
)
|
|
|
|
@patch("builtins.__import__")
|
|
def test_asr_pipeline_with_mlx_whisper(self, mock_import):
|
|
"""Test that AsrPipeline can be initialized with MLX Whisper options."""
|
|
# Mock the mlx_whisper import
|
|
mock_mlx_whisper = Mock()
|
|
mock_import.return_value = mock_mlx_whisper
|
|
|
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
|
asr_options = InlineAsrMlxWhisperOptions(
|
|
repo_id="mlx-community/whisper-tiny-mlx",
|
|
inference_framework=InferenceAsrFramework.MLX,
|
|
language="en",
|
|
task="transcribe",
|
|
word_timestamps=True,
|
|
no_speech_threshold=0.6,
|
|
logprob_threshold=-1.0,
|
|
compression_ratio_threshold=2.4,
|
|
)
|
|
pipeline_options = AsrPipelineOptions(
|
|
asr_options=asr_options,
|
|
accelerator_options=accelerator_options,
|
|
)
|
|
|
|
pipeline = AsrPipeline(pipeline_options)
|
|
assert isinstance(pipeline._model, _MlxWhisperModel)
|
|
assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
|