Fix pre-commit checks and add proper type safety

Ken Steele
2025-10-02 04:53:49 -07:00
parent 94803317a3
commit 21905e8ace
7 changed files with 135 additions and 94 deletions

View File

@@ -615,7 +615,17 @@ def convert(  # noqa: C901
     # Auto-detect pipeline based on input file formats
     if pipeline == ProcessingPipeline.STANDARD:
         # Check if any input files are audio files by extension
-        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
+        audio_extensions = {
+            ".mp3",
+            ".wav",
+            ".m4a",
+            ".aac",
+            ".ogg",
+            ".flac",
+            ".mp4",
+            ".avi",
+            ".mov",
+        }
         for path in input_doc_paths:
             if path.suffix.lower() in audio_extensions:
                 pipeline = ProcessingPipeline.ASR
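
Note: the reformatting above is behavior-preserving; pipeline auto-detection still keys solely off the file suffix. A minimal standalone sketch of that check, assuming nothing beyond the stdlib (names here are illustrative, not the CLI's own):

    from pathlib import Path

    AUDIO_EXTENSIONS = {
        ".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac", ".mp4", ".avi", ".mov",
    }

    def is_audio_input(path: Path) -> bool:
        # suffix.lower() makes the check case-insensitive ("clip.M4A" matches too)
        return path.suffix.lower() in AUDIO_EXTENSIONS

    assert is_audio_input(Path("meeting.M4A"))
    assert not is_audio_input(Path("report.pdf"))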

View File

@@ -10,13 +10,14 @@ from docling.datamodel.pipeline_options_asr_model import (
     # AsrResponseFormat,
     # ApiAsrOptions,
     InferenceAsrFramework,
-    InlineAsrNativeWhisperOptions,
     InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
     TransformersModelType,
 )
 
 _log = logging.getLogger(__name__)
 
+
 def _get_whisper_tiny_model():
     """
     Get the best Whisper Tiny model for the current hardware.
@@ -27,6 +28,7 @@ def _get_whisper_tiny_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -34,6 +36,7 @@ def _get_whisper_tiny_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -66,6 +69,7 @@ def _get_whisper_tiny_model():
 # Create the model instance
 WHISPER_TINY = _get_whisper_tiny_model()
 
+
 def _get_whisper_small_model():
     """
     Get the best Whisper Small model for the current hardware.
@@ -76,6 +80,7 @@ def _get_whisper_small_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -83,6 +88,7 @@ def _get_whisper_small_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -115,6 +121,7 @@ def _get_whisper_small_model():
 # Create the model instance
 WHISPER_SMALL = _get_whisper_small_model()
 
+
 def _get_whisper_medium_model():
     """
     Get the best Whisper Medium model for the current hardware.
@@ -125,6 +132,7 @@ def _get_whisper_medium_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -132,6 +140,7 @@ def _get_whisper_medium_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -164,6 +173,7 @@ def _get_whisper_medium_model():
 # Create the model instance
 WHISPER_MEDIUM = _get_whisper_medium_model()
 
+
 def _get_whisper_base_model():
     """
     Get the best Whisper Base model for the current hardware.
@@ -174,6 +184,7 @@ def _get_whisper_base_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -181,6 +192,7 @@ def _get_whisper_base_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -213,6 +225,7 @@ def _get_whisper_base_model():
 # Create the model instance
 WHISPER_BASE = _get_whisper_base_model()
 
+
 def _get_whisper_large_model():
     """
     Get the best Whisper Large model for the current hardware.
@@ -223,6 +236,7 @@ def _get_whisper_large_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -230,6 +244,7 @@ def _get_whisper_large_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -262,6 +277,7 @@ def _get_whisper_large_model():
 # Create the model instance
 WHISPER_LARGE = _get_whisper_large_model()
 
+
 def _get_whisper_turbo_model():
     """
     Get the best Whisper Turbo model for the current hardware.
@@ -272,6 +288,7 @@ def _get_whisper_turbo_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -279,6 +296,7 @@ def _get_whisper_turbo_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
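
Note: every _get_whisper_*_model() helper repeats the same capability probe these hunks touch. A condensed sketch of that shared pattern, assuming only that torch and mlx-whisper are optional dependencies (the helper name is illustrative, not docling's):

    def _detect_asr_backend() -> str:
        """Pick the preferred Whisper backend for the current hardware."""
        try:
            import torch

            # MPS must be both compiled in and usable at runtime.
            has_mps = (
                torch.backends.mps.is_built() and torch.backends.mps.is_available()
            )
        except ImportError:
            has_mps = False

        try:
            import mlx_whisper  # noqa: F401  # presence check only

            has_mlx_whisper = True
        except ImportError:
            has_mlx_whisper = False

        # Prefer MLX on Apple Silicon when mlx-whisper is installed.
        return "mlx" if (has_mps and has_mlx_whisper) else "native"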

View File

@@ -63,6 +63,7 @@ class InlineAsrMlxWhisperOptions(InlineAsrOptions):
     Uses mlx-whisper library for efficient inference on Apple Silicon devices.
     """
+
     inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
     language: str = "en"
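
Note: the added blank line only satisfies the formatter; the class body is unchanged, with inference_framework defaulting to MLX and language to "en". A quick usage sketch (the repo_id value and the minimal field set are assumptions for illustration):

    from docling.datamodel.pipeline_options_asr_model import (
        InferenceAsrFramework,
        InlineAsrMlxWhisperOptions,
    )

    # repo_id here is illustrative; any MLX-converted Whisper checkpoint id works.
    opts = InlineAsrMlxWhisperOptions(repo_id="mlx-community/whisper-base-mlx")
    assert opts.inference_framework == InferenceAsrFramework.MLX
    assert opts.language == "en"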

View File

@@ -4,7 +4,7 @@ import re
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union, cast
 
 from docling_core.types.doc import DoclingDocument, DocumentOrigin
@@ -32,8 +32,8 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
-    InlineAsrNativeWhisperOptions,
     InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
 )
@@ -334,7 +334,7 @@ class _MlxWhisperModel:
                 start_time=segment.get("start"),
                 end_time=segment.get("end"),
                 text=segment.get("text", "").strip(),
-                words=[]
+                words=[],
             )
 
             # Add word-level timestamps if available
@@ -359,26 +359,27 @@ class AsrPipeline(BasePipeline):
         self.keep_backend = True
 
         self.pipeline_options: AsrPipelineOptions = pipeline_options
+        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]
 
         if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
-            asr_options: InlineAsrNativeWhisperOptions = (
+            native_asr_options: InlineAsrNativeWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=native_asr_options,
             )
         elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
-            asr_options: InlineAsrMlxWhisperOptions = (
+            mlx_asr_options: InlineAsrMlxWhisperOptions = (
                 self.pipeline_options.asr_options
            )
             self._model = _MlxWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=mlx_asr_options,
             )
         else:
             _log.error(f"No model support for {self.pipeline_options.asr_options}")

View File

@@ -12,19 +12,19 @@ from pathlib import Path
 # Add the repository root to the path so we can import docling
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
     WHISPER_BASE,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
 )
-from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_options import AsrPipelineOptions
+from docling.document_converter import AudioFormatOption, DocumentConverter
 from docling.pipeline.asr_pipeline import AsrPipeline
-from docling.document_converter import DocumentConverter, AudioFormatOption
 
 
 def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
@@ -51,7 +51,9 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b
     }
 
     if model_size not in model_map:
-        raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}")
+        raise ValueError(
+            f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}"
+        )
 
     asr_options = model_map[model_size]
@@ -105,10 +107,14 @@ def main():
     try:
         print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...")
-        print("Note: MLX optimization is automatically used on Apple Silicon when available.")
+        print(
+            "Note: MLX optimization is automatically used on Apple Silicon when available."
+        )
         print()
 
-        transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size)
+        transcribed_text = transcribe_audio_with_mlx_whisper(
+            audio_file_path, model_size
+        )
 
         print("Transcription Result:")
         print("=" * 50)

View File

@@ -1,25 +1,27 @@
 """
 Test MLX Whisper integration for Apple Silicon ASR pipeline.
 """
-import pytest
+
 from pathlib import Path
 from unittest.mock import Mock, patch
 
+import pytest
+
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
     WHISPER_BASE,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
 )
+from docling.datamodel.pipeline_options import AsrPipelineOptions
 from docling.datamodel.pipeline_options_asr_model import (
     InferenceAsrFramework,
     InlineAsrMlxWhisperOptions,
 )
-from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
 from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
-from docling.datamodel.pipeline_options import AsrPipelineOptions
 
 
 class TestMlxWhisperIntegration:
@@ -47,16 +49,16 @@ class TestMlxWhisperIntegration:
         # these models would automatically use MLX
 
         # Check that the models exist and have the correct structure
-        assert hasattr(WHISPER_TURBO, 'inference_framework')
-        assert hasattr(WHISPER_TURBO, 'repo_id')
-        assert hasattr(WHISPER_BASE, 'inference_framework')
-        assert hasattr(WHISPER_BASE, 'repo_id')
-        assert hasattr(WHISPER_SMALL, 'inference_framework')
-        assert hasattr(WHISPER_SMALL, 'repo_id')
+        assert hasattr(WHISPER_TURBO, "inference_framework")
+        assert hasattr(WHISPER_TURBO, "repo_id")
+        assert hasattr(WHISPER_BASE, "inference_framework")
+        assert hasattr(WHISPER_BASE, "repo_id")
+        assert hasattr(WHISPER_SMALL, "inference_framework")
+        assert hasattr(WHISPER_SMALL, "repo_id")
 
-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_model_initialization(self, mock_import):
         """Test MLX Whisper model initialization."""
         # Mock the mlx_whisper import
@@ -102,7 +104,10 @@ class TestMlxWhisperIntegration:
             compression_ratio_threshold=2.4,
         )
 
-        with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
+        with patch(
+            "builtins.__import__",
+            side_effect=ImportError("No module named 'mlx_whisper'"),
+        ):
             with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                 _MlxWhisperModel(
                     enabled=True,
@@ -111,7 +116,7 @@ class TestMlxWhisperIntegration:
                     asr_options=asr_options,
                 )
 
-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_transcribe(self, mock_import):
         """Test MLX Whisper transcription method."""
         # Mock the mlx_whisper module and its transcribe function
@@ -128,7 +133,7 @@ class TestMlxWhisperIntegration:
                     "words": [
                         {"start": 0.0, "end": 0.5, "word": "Hello"},
                         {"start": 0.5, "end": 1.0, "word": "world"},
-                    ]
+                    ],
                 }
             ]
         }
@@ -178,7 +183,7 @@ class TestMlxWhisperIntegration:
             compression_ratio_threshold=2.4,
         )
 
-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_asr_pipeline_with_mlx_whisper(self, mock_import):
         """Test that AsrPipeline can be initialized with MLX Whisper options."""
         # Mock the mlx_whisper import
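
Note: patching builtins.__import__ makes every import statement inside the with-block raise, which is how these tests simulate a missing optional dependency. A self-contained sketch of the same technique (load_backend is a stand-in, not docling code):

    from unittest.mock import patch

    import pytest

    def load_backend() -> str:
        import mlx_whisper  # noqa: F401  # lazy optional import

        return "mlx"

    def test_missing_dependency_is_reported() -> None:
        with patch(
            "builtins.__import__",
            side_effect=ImportError("No module named 'mlx_whisper'"),
        ):
            with pytest.raises(ImportError, match="mlx_whisper"):
                load_backend()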