mirror of https://github.com/DS4SD/docling.git
fix pre-commit checks and added proper type safety
@@ -611,17 +611,27 @@ def convert( # noqa: C901
    ocr_options.psm = psm

    accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)

    # Auto-detect pipeline based on input file formats
    if pipeline == ProcessingPipeline.STANDARD:
        # Check if any input files are audio files by extension
        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
        audio_extensions = {
            ".mp3",
            ".wav",
            ".m4a",
            ".aac",
            ".ogg",
            ".flac",
            ".mp4",
            ".avi",
            ".mov",
        }
        for path in input_doc_paths:
            if path.suffix.lower() in audio_extensions:
                pipeline = ProcessingPipeline.ASR
                _log.info(f"Auto-detected ASR pipeline for audio file: {path}")
                break

    # pipeline_options: PaginatedPipelineOptions
    pipeline_options: PipelineOptions
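For readers skimming the diff: the hunk above routes the CLI to the ASR pipeline whenever any input file carries an audio or video extension. A minimal standalone sketch of that check (the ProcessingPipeline enum below is a hypothetical stand-in for docling's CLI enum, not its actual definition):

from enum import Enum
from pathlib import Path


class ProcessingPipeline(str, Enum):  # hypothetical stand-in for the CLI enum
    STANDARD = "standard"
    ASR = "asr"


AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac", ".mp4", ".avi", ".mov"}


def autodetect_pipeline(input_doc_paths: list[Path]) -> ProcessingPipeline:
    """Return ASR as soon as one input looks like an audio/video file."""
    for path in input_doc_paths:
        if path.suffix.lower() in AUDIO_EXTENSIONS:
            return ProcessingPipeline.ASR
    return ProcessingPipeline.STANDARD


# Example: a mixed batch is routed to ASR because of the .MP3 file.
print(autodetect_pipeline([Path("report.pdf"), Path("meeting.MP3")]))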
@@ -10,34 +10,37 @@ from docling.datamodel.pipeline_options_asr_model import (
    # AsrResponseFormat,
    # ApiAsrOptions,
    InferenceAsrFramework,
    InlineAsrNativeWhisperOptions,
    InlineAsrMlxWhisperOptions,
    InlineAsrNativeWhisperOptions,
    TransformersModelType,
)

_log = logging.getLogger(__name__)


def _get_whisper_tiny_model():
    """
    Get the best Whisper Tiny model for the current hardware.

    Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Tiny.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
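The selection logic in _get_whisper_tiny_model boils down to two guarded imports. A self-contained sketch of just that capability probe, assuming nothing beyond optional torch and mlx-whisper installs:

def mlx_whisper_available() -> bool:
    """True only when PyTorch reports Apple Silicon MPS and mlx-whisper is importable."""
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    try:
        import mlx_whisper  # noqa: F401

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    return has_mps and has_mlx_whisper


# Prints False on Linux/CUDA machines; True on an M-series Mac with mlx-whisper installed.
print(mlx_whisper_available())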
@@ -66,27 +69,30 @@ def _get_whisper_tiny_model():
# Create the model instance
WHISPER_TINY = _get_whisper_tiny_model()


def _get_whisper_small_model():
    """
    Get the best Whisper Small model for the current hardware.

    Automatically selects MLX Whisper Small for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Small.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
@@ -115,27 +121,30 @@ def _get_whisper_small_model():
# Create the model instance
WHISPER_SMALL = _get_whisper_small_model()


def _get_whisper_medium_model():
    """
    Get the best Whisper Medium model for the current hardware.

    Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Medium.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
@@ -164,27 +173,30 @@ def _get_whisper_medium_model():
# Create the model instance
WHISPER_MEDIUM = _get_whisper_medium_model()


def _get_whisper_base_model():
    """
    Get the best Whisper Base model for the current hardware.

    Automatically selects MLX Whisper Base for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Base.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
@@ -213,27 +225,30 @@ def _get_whisper_base_model():
# Create the model instance
WHISPER_BASE = _get_whisper_base_model()


def _get_whisper_large_model():
    """
    Get the best Whisper Large model for the current hardware.

    Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Large.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
@@ -262,27 +277,30 @@ def _get_whisper_large_model():
# Create the model instance
WHISPER_LARGE = _get_whisper_large_model()


def _get_whisper_turbo_model():
    """
    Get the best Whisper Turbo model for the current hardware.

    Automatically selects MLX Whisper Turbo for Apple Silicon (MPS) if available,
    otherwise falls back to native Whisper Turbo.
    """
    # Check if MPS is available (Apple Silicon)
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    # Check if mlx-whisper is available
    try:
        import mlx_whisper  # type: ignore

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    # Use MLX Whisper if both MPS and mlx-whisper are available
    if has_mps and has_mlx_whisper:
        return InlineAsrMlxWhisperOptions(
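All six _get_whisper_*_model helpers above repeat the same probe and differ only in the model size. A hypothetical parameterization of that pattern, returning just a framework name and repository id rather than docling's option objects (repo ids follow the mlx-community naming used elsewhere in this diff and are an assumption for sizes not shown):

def pick_whisper_backend(size: str, has_mps: bool, has_mlx_whisper: bool) -> tuple[str, str]:
    """Map a Whisper size to (framework, repo_id) based on detected capabilities."""
    if has_mps and has_mlx_whisper:
        return "mlx", f"mlx-community/whisper-{size}-mlx"
    return "native", size  # native whisper loads models by their size name


# The same decision the six helpers make, expressed once per size.
for size in ("tiny", "base", "small", "medium", "large", "turbo"):
    print(size, "->", pick_whisper_backend(size, has_mps=True, has_mlx_whisper=True))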
@@ -60,9 +60,10 @@ class InlineAsrNativeWhisperOptions(InlineAsrOptions):
class InlineAsrMlxWhisperOptions(InlineAsrOptions):
    """
    MLX Whisper options for Apple Silicon optimization.

    Uses mlx-whisper library for efficient inference on Apple Silicon devices.
    """

    inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX

    language: str = "en"
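Structurally, such an options class is a declarative record with a framework discriminator. A rough, hypothetical sketch using plain dataclasses rather than docling's pydantic models (field names mirror those exercised by the tests later in this diff; defaults are assumptions):

from dataclasses import dataclass, field
from enum import Enum


class InferenceAsrFramework(str, Enum):  # local stand-in mirroring the enum used in the diff
    WHISPER = "whisper"
    MLX = "mlx"


@dataclass
class AsrOptionsSketch:
    repo_id: str
    language: str = "en"
    inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
    word_timestamps: bool = True


opts = AsrOptionsSketch(repo_id="mlx-community/whisper-tiny-mlx")
print(opts.inference_framework, opts.repo_id)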
@@ -4,7 +4,7 @@ import re
import tempfile
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Union, cast
from typing import TYPE_CHECKING, List, Optional, Union, cast

from docling_core.types.doc import DoclingDocument, DocumentOrigin

@@ -32,8 +32,8 @@ from docling.datamodel.pipeline_options import (
    AsrPipelineOptions,
)
from docling.datamodel.pipeline_options_asr_model import (
    InlineAsrNativeWhisperOptions,
    InlineAsrMlxWhisperOptions,
    InlineAsrNativeWhisperOptions,
    # AsrResponseFormat,
    InlineAsrOptions,
)

@@ -263,7 +263,7 @@ class _MlxWhisperModel:
        self.model_name = asr_options.repo_id
        _log.info(f"loading _MlxWhisperModel({self.model_name})")

        # MLX Whisper models are loaded differently - they use HuggingFace repos
        self.model_path = self.model_name
@@ -308,10 +308,10 @@ class _MlxWhisperModel:
    def transcribe(self, fpath: Path) -> list[_ConversationItem]:
        """
        Transcribe audio using MLX Whisper.

        Args:
            fpath: Path to audio file

        Returns:
            List of conversation items with timestamps
        """

@@ -327,16 +327,16 @@ class _MlxWhisperModel:
        )

        convo: list[_ConversationItem] = []

        # MLX Whisper returns segments similar to native Whisper
        for segment in result.get("segments", []):
            item = _ConversationItem(
                start_time=segment.get("start"),
                end_time=segment.get("end"),
                text=segment.get("text", "").strip(),
                words=[]
                words=[],
            )

            # Add word-level timestamps if available
            if self.word_timestamps and "words" in segment:
                item.words = []
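The transcribe path above only depends on the dictionary shape mlx-whisper returns (a "segments" list with "start"/"end"/"text" and optional "words"). A standalone sketch of that conversion, using a plain dict in place of a real transcription result and a simplified stand-in for docling's _ConversationItem:

from dataclasses import dataclass, field


@dataclass
class ConversationItem:  # simplified stand-in for docling's _ConversationItem
    start_time: float
    end_time: float
    text: str
    words: list = field(default_factory=list)


def segments_to_items(result: dict, word_timestamps: bool = True) -> list[ConversationItem]:
    """Convert a Whisper-style result dict into timestamped conversation items."""
    items = []
    for segment in result.get("segments", []):
        item = ConversationItem(
            start_time=segment.get("start"),
            end_time=segment.get("end"),
            text=segment.get("text", "").strip(),
        )
        if word_timestamps and "words" in segment:
            item.words = list(segment["words"])
        items.append(item)
    return items


fake_result = {
    "segments": [
        {"start": 0.0, "end": 1.0, "text": " Hello world ",
         "words": [{"start": 0.0, "end": 0.5, "word": "Hello"}]}
    ]
}
print(segments_to_items(fake_result))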
@@ -359,26 +359,27 @@ class AsrPipeline(BasePipeline):
        self.keep_backend = True

        self.pipeline_options: AsrPipelineOptions = pipeline_options
        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]

        if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
            asr_options: InlineAsrNativeWhisperOptions = (
            native_asr_options: InlineAsrNativeWhisperOptions = (
                self.pipeline_options.asr_options
            )
            self._model = _NativeWhisperModel(
                enabled=True,  # must be always enabled for this pipeline to make sense.
                artifacts_path=self.artifacts_path,
                accelerator_options=pipeline_options.accelerator_options,
                asr_options=asr_options,
                asr_options=native_asr_options,
            )
        elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
            asr_options: InlineAsrMlxWhisperOptions = (
            mlx_asr_options: InlineAsrMlxWhisperOptions = (
                self.pipeline_options.asr_options
            )
            self._model = _MlxWhisperModel(
                enabled=True,  # must be always enabled for this pipeline to make sense.
                artifacts_path=self.artifacts_path,
                accelerator_options=pipeline_options.accelerator_options,
                asr_options=asr_options,
                asr_options=mlx_asr_options,
            )
        else:
            _log.error(f"No model support for {self.pipeline_options.asr_options}")
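The renames in this hunk (asr_options becoming native_asr_options and mlx_asr_options) appear to address the commit's "proper type safety" goal: annotating the same variable name with two different types in sibling branches trips mypy's redefinition check, while one annotated name per branch keeps the narrowing explicit. A minimal sketch of the same pattern outside docling:

from typing import Union


class NativeOptions: ...
class MlxOptions: ...


def build_model(options: Union[NativeOptions, MlxOptions]) -> str:
    """Dispatch on the concrete options type, with one annotated name per branch."""
    if isinstance(options, NativeOptions):
        native_options: NativeOptions = options
        return f"native model from {native_options!r}"
    elif isinstance(options, MlxOptions):
        mlx_options: MlxOptions = options
        return f"mlx model from {mlx_options!r}"
    raise ValueError(f"No model support for {options!r}")


print(build_model(MlxOptions()))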
docs/examples/minimal_asr_pipeline.py (vendored, 2 changed lines)
@@ -43,7 +43,7 @@ def get_asr_converter():
    implementation for your hardware:
    - MLX Whisper Turbo for Apple Silicon (M1/M2/M3) with mlx-whisper installed
    - Native Whisper Turbo as fallback

    You can swap in another model spec from `docling.datamodel.asr_model_specs`
    to experiment with different model sizes.
    """
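As the docstring notes, swapping the model spec is a one-line change. A sketch of what that looks like for a converter wired the way this example does it (the API names come from the imports elsewhere in this diff; treat the exact call shape as an assumption if your docling version differs):

from docling.datamodel.asr_model_specs import WHISPER_BASE  # e.g. instead of WHISPER_TURBO
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline

# Only the spec passed as asr_options changes; the rest of the converter setup stays the same.
pipeline_options = AsrPipelineOptions(asr_options=WHISPER_BASE)
converter = DocumentConverter(
    format_options={
        InputFormat.AUDIO: AudioFormatOption(
            pipeline_cls=AsrPipeline, pipeline_options=pipeline_options
        )
    }
)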
docs/examples/mlx_whisper_example.py (vendored, 54 changed lines)
@@ -12,31 +12,31 @@ from pathlib import Path
# Add the repository root to the path so we can import docling
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import (
    WHISPER_TINY,
    WHISPER_BASE,
    WHISPER_SMALL,
    WHISPER_MEDIUM,
    WHISPER_LARGE,
    WHISPER_MEDIUM,
    WHISPER_SMALL,
    WHISPER_TINY,
    WHISPER_TURBO,
)
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline
from docling.document_converter import DocumentConverter, AudioFormatOption


def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
    """
    Transcribe audio using Whisper models with automatic MLX optimization for Apple Silicon.

    Args:
        audio_file_path: Path to the audio file to transcribe
        model_size: Size of the Whisper model to use
            ("tiny", "base", "small", "medium", "large", "turbo")
        Note: MLX optimization is automatically used on Apple Silicon when available

    Returns:
        The transcribed text
    """
@@ -49,21 +49,23 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b
        "large": WHISPER_LARGE,
        "turbo": WHISPER_TURBO,
    }

    if model_size not in model_map:
        raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}")
        raise ValueError(
            f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}"
        )

    asr_options = model_map[model_size]

    # Configure accelerator options for Apple Silicon
    accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)

    # Create pipeline options
    pipeline_options = AsrPipelineOptions(
        asr_options=asr_options,
        accelerator_options=accelerator_options,
    )

    # Create document converter with MLX Whisper configuration
    converter = DocumentConverter(
        format_options={
@@ -73,16 +75,16 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "b
            )
        }
    )

    # Run transcription
    result = converter.convert(Path(audio_file_path))

    if result.status.value == "success":
        # Extract text from the document
        text_content = []
        for item in result.document.texts:
            text_content.append(item.text)

        return "\n".join(text_content)
    else:
        raise RuntimeError(f"Transcription failed: {result.status}")
@@ -95,26 +97,30 @@ def main():
        print("Model sizes: tiny, base, small, medium, large, turbo")
        print("Example: python mlx_whisper_example.py audio.wav base")
        sys.exit(1)

    audio_file_path = sys.argv[1]
    model_size = sys.argv[2] if len(sys.argv) > 2 else "base"

    if not Path(audio_file_path).exists():
        print(f"Error: Audio file '{audio_file_path}' not found.")
        sys.exit(1)

    try:
        print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...")
        print("Note: MLX optimization is automatically used on Apple Silicon when available.")
        print(
            "Note: MLX optimization is automatically used on Apple Silicon when available."
        )
        print()

        transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size)

        transcribed_text = transcribe_audio_with_mlx_whisper(
            audio_file_path, model_size
        )

        print("Transcription Result:")
        print("=" * 50)
        print(transcribed_text)
        print("=" * 50)

    except ImportError as e:
        print(f"Error: {e}")
        print("Please install mlx-whisper: pip install mlx-whisper")
@@ -1,25 +1,27 @@
"""
Test MLX Whisper integration for Apple Silicon ASR pipeline.
"""
import pytest

from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import (
    WHISPER_TINY,
    WHISPER_BASE,
    WHISPER_SMALL,
    WHISPER_MEDIUM,
    WHISPER_LARGE,
    WHISPER_MEDIUM,
    WHISPER_SMALL,
    WHISPER_TINY,
    WHISPER_TURBO,
)
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrMlxWhisperOptions,
)
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
from docling.datamodel.pipeline_options import AsrPipelineOptions


class TestMlxWhisperIntegration:
@@ -32,7 +34,7 @@ class TestMlxWhisperIntegration:
            language="en",
            task="transcribe",
        )

        assert options.inference_framework == InferenceAsrFramework.MLX
        assert options.repo_id == "mlx-community/whisper-tiny-mlx"
        assert options.language == "en"

@@ -45,24 +47,24 @@ class TestMlxWhisperIntegration:
        # This test verifies that the models are correctly configured
        # In a real Apple Silicon environment with mlx-whisper installed,
        # these models would automatically use MLX

        # Check that the models exist and have the correct structure
        assert hasattr(WHISPER_TURBO, 'inference_framework')
        assert hasattr(WHISPER_TURBO, 'repo_id')

        assert hasattr(WHISPER_BASE, 'inference_framework')
        assert hasattr(WHISPER_BASE, 'repo_id')

        assert hasattr(WHISPER_SMALL, 'inference_framework')
        assert hasattr(WHISPER_SMALL, 'repo_id')

    @patch('builtins.__import__')
        # Check that the models exist and have the correct structure
        assert hasattr(WHISPER_TURBO, "inference_framework")
        assert hasattr(WHISPER_TURBO, "repo_id")

        assert hasattr(WHISPER_BASE, "inference_framework")
        assert hasattr(WHISPER_BASE, "repo_id")

        assert hasattr(WHISPER_SMALL, "inference_framework")
        assert hasattr(WHISPER_SMALL, "repo_id")

    @patch("builtins.__import__")
    def test_mlx_whisper_model_initialization(self, mock_import):
        """Test MLX Whisper model initialization."""
        # Mock the mlx_whisper import
        mock_mlx_whisper = Mock()
        mock_import.return_value = mock_mlx_whisper

        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
@@ -74,14 +76,14 @@ class TestMlxWhisperIntegration:
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        model = _MlxWhisperModel(
            enabled=True,
            artifacts_path=None,
            accelerator_options=accelerator_options,
            asr_options=asr_options,
        )

        assert model.enabled is True
        assert model.model_path == "mlx-community/whisper-tiny-mlx"
        assert model.language == "en"

@@ -101,8 +103,11 @@ class TestMlxWhisperIntegration:
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):

        with patch(
            "builtins.__import__",
            side_effect=ImportError("No module named 'mlx_whisper'"),
        ):
            with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                _MlxWhisperModel(
                    enabled=True,

@@ -111,13 +116,13 @@ class TestMlxWhisperIntegration:
                    asr_options=asr_options,
                )

    @patch('builtins.__import__')
    @patch("builtins.__import__")
    def test_mlx_whisper_transcribe(self, mock_import):
        """Test MLX Whisper transcription method."""
        # Mock the mlx_whisper module and its transcribe function
        mock_mlx_whisper = Mock()
        mock_import.return_value = mock_mlx_whisper

        # Mock the transcribe result
        mock_result = {
            "segments": [
@@ -128,12 +133,12 @@ class TestMlxWhisperIntegration:
                    "words": [
                        {"start": 0.0, "end": 0.5, "word": "Hello"},
                        {"start": 0.5, "end": 1.0, "word": "world"},
                    ]
                    ],
                }
            ]
        }
        mock_mlx_whisper.transcribe.return_value = mock_result

        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",

@@ -145,18 +150,18 @@ class TestMlxWhisperIntegration:
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        model = _MlxWhisperModel(
            enabled=True,
            artifacts_path=None,
            accelerator_options=accelerator_options,
            asr_options=asr_options,
        )

        # Test transcription
        audio_path = Path("test_audio.wav")
        result = model.transcribe(audio_path)

        # Verify the result
        assert len(result) == 1
        assert result[0].start_time == 0.0

@@ -165,7 +170,7 @@ class TestMlxWhisperIntegration:
        assert len(result[0].words) == 2
        assert result[0].words[0].text == "Hello"
        assert result[0].words[1].text == "world"

        # Verify mlx_whisper.transcribe was called with correct parameters
        mock_mlx_whisper.transcribe.assert_called_once_with(
            str(audio_path),

@@ -178,13 +183,13 @@ class TestMlxWhisperIntegration:
            compression_ratio_threshold=2.4,
        )

    @patch('builtins.__import__')
    @patch("builtins.__import__")
    def test_asr_pipeline_with_mlx_whisper(self, mock_import):
        """Test that AsrPipeline can be initialized with MLX Whisper options."""
        # Mock the mlx_whisper import
        mock_mlx_whisper = Mock()
        mock_import.return_value = mock_mlx_whisper

        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",

@@ -200,7 +205,7 @@ class TestMlxWhisperIntegration:
            asr_options=asr_options,
            accelerator_options=accelerator_options,
        )

        pipeline = AsrPipeline(pipeline_options)
        assert isinstance(pipeline._model, _MlxWhisperModel)
        assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"