mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
added mlx-whisper example and test. update docling cli to use MLX automatically if present.
This commit is contained in:
@@ -611,6 +611,17 @@ def convert( # noqa: C901
|
|||||||
ocr_options.psm = psm
|
ocr_options.psm = psm
|
||||||
|
|
||||||
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
|
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
|
||||||
|
|
||||||
|
# Auto-detect pipeline based on input file formats
|
||||||
|
if pipeline == ProcessingPipeline.STANDARD:
|
||||||
|
# Check if any input files are audio files by extension
|
||||||
|
audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
|
||||||
|
for path in input_doc_paths:
|
||||||
|
if path.suffix.lower() in audio_extensions:
|
||||||
|
pipeline = ProcessingPipeline.ASR
|
||||||
|
_log.info(f"Auto-detected ASR pipeline for audio file: {path}")
|
||||||
|
break
|
||||||
|
|
||||||
# pipeline_options: PaginatedPipelineOptions
|
# pipeline_options: PaginatedPipelineOptions
|
||||||
pipeline_options: PipelineOptions
|
pipeline_options: PipelineOptions
|
||||||
|
|
||||||
@@ -749,6 +760,10 @@ def convert( # noqa: C901
|
|||||||
|
|
||||||
elif pipeline == ProcessingPipeline.ASR:
|
elif pipeline == ProcessingPipeline.ASR:
|
||||||
pipeline_options = AsrPipelineOptions(
|
pipeline_options = AsrPipelineOptions(
|
||||||
|
accelerator_options=AcceleratorOptions(
|
||||||
|
device=device,
|
||||||
|
num_threads=num_threads,
|
||||||
|
),
|
||||||
# enable_remote_services=enable_remote_services,
|
# enable_remote_services=enable_remote_services,
|
||||||
# artifacts_path = artifacts_path
|
# artifacts_path = artifacts_path
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,41 @@ from docling.datamodel.pipeline_options_asr_model import (
|
|||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
WHISPER_TINY = InlineAsrNativeWhisperOptions(
|
def _get_whisper_tiny_model():
|
||||||
|
"""
|
||||||
|
Get the best Whisper Tiny model for the current hardware.
|
||||||
|
|
||||||
|
Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
|
||||||
|
otherwise falls back to native Whisper Tiny.
|
||||||
|
"""
|
||||||
|
# Check if MPS is available (Apple Silicon)
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||||
|
except ImportError:
|
||||||
|
has_mps = False
|
||||||
|
|
||||||
|
# Check if mlx-whisper is available
|
||||||
|
try:
|
||||||
|
import mlx_whisper # type: ignore
|
||||||
|
has_mlx_whisper = True
|
||||||
|
except ImportError:
|
||||||
|
has_mlx_whisper = False
|
||||||
|
|
||||||
|
# Use MLX Whisper if both MPS and mlx-whisper are available
|
||||||
|
if has_mps and has_mlx_whisper:
|
||||||
|
return InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-tiny-mlx",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return InlineAsrNativeWhisperOptions(
|
||||||
repo_id="tiny",
|
repo_id="tiny",
|
||||||
inference_framework=InferenceAsrFramework.WHISPER,
|
inference_framework=InferenceAsrFramework.WHISPER,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
@@ -26,7 +60,11 @@ WHISPER_TINY = InlineAsrNativeWhisperOptions(
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_new_tokens=256,
|
max_new_tokens=256,
|
||||||
max_time_chunk=30.0,
|
max_time_chunk=30.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Create the model instance
|
||||||
|
WHISPER_TINY = _get_whisper_tiny_model()
|
||||||
|
|
||||||
def _get_whisper_small_model():
|
def _get_whisper_small_model():
|
||||||
"""
|
"""
|
||||||
@@ -77,7 +115,41 @@ def _get_whisper_small_model():
|
|||||||
# Create the model instance
|
# Create the model instance
|
||||||
WHISPER_SMALL = _get_whisper_small_model()
|
WHISPER_SMALL = _get_whisper_small_model()
|
||||||
|
|
||||||
WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
|
def _get_whisper_medium_model():
|
||||||
|
"""
|
||||||
|
Get the best Whisper Medium model for the current hardware.
|
||||||
|
|
||||||
|
Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
|
||||||
|
otherwise falls back to native Whisper Medium.
|
||||||
|
"""
|
||||||
|
# Check if MPS is available (Apple Silicon)
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||||
|
except ImportError:
|
||||||
|
has_mps = False
|
||||||
|
|
||||||
|
# Check if mlx-whisper is available
|
||||||
|
try:
|
||||||
|
import mlx_whisper # type: ignore
|
||||||
|
has_mlx_whisper = True
|
||||||
|
except ImportError:
|
||||||
|
has_mlx_whisper = False
|
||||||
|
|
||||||
|
# Use MLX Whisper if both MPS and mlx-whisper are available
|
||||||
|
if has_mps and has_mlx_whisper:
|
||||||
|
return InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-medium-mlx-8bit",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return InlineAsrNativeWhisperOptions(
|
||||||
repo_id="medium",
|
repo_id="medium",
|
||||||
inference_framework=InferenceAsrFramework.WHISPER,
|
inference_framework=InferenceAsrFramework.WHISPER,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
@@ -86,7 +158,11 @@ WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_new_tokens=256,
|
max_new_tokens=256,
|
||||||
max_time_chunk=30.0,
|
max_time_chunk=30.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Create the model instance
|
||||||
|
WHISPER_MEDIUM = _get_whisper_medium_model()
|
||||||
|
|
||||||
def _get_whisper_base_model():
|
def _get_whisper_base_model():
|
||||||
"""
|
"""
|
||||||
@@ -137,7 +213,41 @@ def _get_whisper_base_model():
|
|||||||
# Create the model instance
|
# Create the model instance
|
||||||
WHISPER_BASE = _get_whisper_base_model()
|
WHISPER_BASE = _get_whisper_base_model()
|
||||||
|
|
||||||
WHISPER_LARGE = InlineAsrNativeWhisperOptions(
|
def _get_whisper_large_model():
|
||||||
|
"""
|
||||||
|
Get the best Whisper Large model for the current hardware.
|
||||||
|
|
||||||
|
Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
|
||||||
|
otherwise falls back to native Whisper Large.
|
||||||
|
"""
|
||||||
|
# Check if MPS is available (Apple Silicon)
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||||
|
except ImportError:
|
||||||
|
has_mps = False
|
||||||
|
|
||||||
|
# Check if mlx-whisper is available
|
||||||
|
try:
|
||||||
|
import mlx_whisper # type: ignore
|
||||||
|
has_mlx_whisper = True
|
||||||
|
except ImportError:
|
||||||
|
has_mlx_whisper = False
|
||||||
|
|
||||||
|
# Use MLX Whisper if both MPS and mlx-whisper are available
|
||||||
|
if has_mps and has_mlx_whisper:
|
||||||
|
return InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-large-mlx-8bit",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return InlineAsrNativeWhisperOptions(
|
||||||
repo_id="large",
|
repo_id="large",
|
||||||
inference_framework=InferenceAsrFramework.WHISPER,
|
inference_framework=InferenceAsrFramework.WHISPER,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
@@ -146,7 +256,11 @@ WHISPER_LARGE = InlineAsrNativeWhisperOptions(
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_new_tokens=256,
|
max_new_tokens=256,
|
||||||
max_time_chunk=30.0,
|
max_time_chunk=30.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Create the model instance
|
||||||
|
WHISPER_LARGE = _get_whisper_large_model()
|
||||||
|
|
||||||
def _get_whisper_turbo_model():
|
def _get_whisper_turbo_model():
|
||||||
"""
|
"""
|
||||||
|
|||||||
129
docs/examples/mlx_whisper_example.py
vendored
Normal file
129
docs/examples/mlx_whisper_example.py
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Example script demonstrating MLX Whisper integration for Apple Silicon.
|
||||||
|
|
||||||
|
This script shows how to use the MLX Whisper models for speech recognition
|
||||||
|
on Apple Silicon devices with optimized performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add the repository root to the path so we can import docling
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from docling.datamodel.asr_model_specs import (
|
||||||
|
WHISPER_TINY,
|
||||||
|
WHISPER_BASE,
|
||||||
|
WHISPER_SMALL,
|
||||||
|
WHISPER_MEDIUM,
|
||||||
|
WHISPER_LARGE,
|
||||||
|
WHISPER_TURBO,
|
||||||
|
)
|
||||||
|
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
|
||||||
|
from docling.datamodel.base_models import InputFormat
|
||||||
|
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
||||||
|
from docling.pipeline.asr_pipeline import AsrPipeline
|
||||||
|
from docling.document_converter import DocumentConverter, AudioFormatOption
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
|
||||||
|
"""
|
||||||
|
Transcribe audio using Whisper models with automatic MLX optimization for Apple Silicon.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_file_path: Path to the audio file to transcribe
|
||||||
|
model_size: Size of the Whisper model to use
|
||||||
|
("tiny", "base", "small", "medium", "large", "turbo")
|
||||||
|
Note: MLX optimization is automatically used on Apple Silicon when available
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The transcribed text
|
||||||
|
"""
|
||||||
|
# Select the appropriate Whisper model (automatically uses MLX on Apple Silicon)
|
||||||
|
model_map = {
|
||||||
|
"tiny": WHISPER_TINY,
|
||||||
|
"base": WHISPER_BASE,
|
||||||
|
"small": WHISPER_SMALL,
|
||||||
|
"medium": WHISPER_MEDIUM,
|
||||||
|
"large": WHISPER_LARGE,
|
||||||
|
"turbo": WHISPER_TURBO,
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_size not in model_map:
|
||||||
|
raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}")
|
||||||
|
|
||||||
|
asr_options = model_map[model_size]
|
||||||
|
|
||||||
|
# Configure accelerator options for Apple Silicon
|
||||||
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||||
|
|
||||||
|
# Create pipeline options
|
||||||
|
pipeline_options = AsrPipelineOptions(
|
||||||
|
asr_options=asr_options,
|
||||||
|
accelerator_options=accelerator_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create document converter with MLX Whisper configuration
|
||||||
|
converter = DocumentConverter(
|
||||||
|
format_options={
|
||||||
|
InputFormat.AUDIO: AudioFormatOption(
|
||||||
|
pipeline_cls=AsrPipeline,
|
||||||
|
pipeline_options=pipeline_options,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run transcription
|
||||||
|
result = converter.convert(Path(audio_file_path))
|
||||||
|
|
||||||
|
if result.status.value == "success":
|
||||||
|
# Extract text from the document
|
||||||
|
text_content = []
|
||||||
|
for item in result.document.texts:
|
||||||
|
text_content.append(item.text)
|
||||||
|
|
||||||
|
return "\n".join(text_content)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Transcription failed: {result.status}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to demonstrate MLX Whisper usage."""
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: python mlx_whisper_example.py <audio_file_path> [model_size]")
|
||||||
|
print("Model sizes: tiny, base, small, medium, large, turbo")
|
||||||
|
print("Example: python mlx_whisper_example.py audio.wav base")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
audio_file_path = sys.argv[1]
|
||||||
|
model_size = sys.argv[2] if len(sys.argv) > 2 else "base"
|
||||||
|
|
||||||
|
if not Path(audio_file_path).exists():
|
||||||
|
print(f"Error: Audio file '{audio_file_path}' not found.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...")
|
||||||
|
print("Note: MLX optimization is automatically used on Apple Silicon when available.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size)
|
||||||
|
|
||||||
|
print("Transcription Result:")
|
||||||
|
print("=" * 50)
|
||||||
|
print(transcribed_text)
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
print("Please install mlx-whisper: pip install mlx-whisper")
|
||||||
|
print("Or install with uv: uv sync --extra asr")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during transcription: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
206
tests/test_asr_mlx_whisper.py
Normal file
206
tests/test_asr_mlx_whisper.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""
|
||||||
|
Test MLX Whisper integration for Apple Silicon ASR pipeline.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from docling.datamodel.asr_model_specs import (
|
||||||
|
WHISPER_TINY,
|
||||||
|
WHISPER_BASE,
|
||||||
|
WHISPER_SMALL,
|
||||||
|
WHISPER_MEDIUM,
|
||||||
|
WHISPER_LARGE,
|
||||||
|
WHISPER_TURBO,
|
||||||
|
)
|
||||||
|
from docling.datamodel.pipeline_options_asr_model import (
|
||||||
|
InferenceAsrFramework,
|
||||||
|
InlineAsrMlxWhisperOptions,
|
||||||
|
)
|
||||||
|
from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
|
||||||
|
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
|
||||||
|
from docling.datamodel.pipeline_options import AsrPipelineOptions
|
||||||
|
|
||||||
|
|
||||||
|
class TestMlxWhisperIntegration:
|
||||||
|
"""Test MLX Whisper model integration."""
|
||||||
|
|
||||||
|
def test_mlx_whisper_options_creation(self):
|
||||||
|
"""Test that MLX Whisper options are created correctly."""
|
||||||
|
options = InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-tiny-mlx",
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert options.inference_framework == InferenceAsrFramework.MLX
|
||||||
|
assert options.repo_id == "mlx-community/whisper-tiny-mlx"
|
||||||
|
assert options.language == "en"
|
||||||
|
assert options.task == "transcribe"
|
||||||
|
assert options.word_timestamps is True
|
||||||
|
assert AcceleratorDevice.MPS in options.supported_devices
|
||||||
|
|
||||||
|
def test_whisper_models_auto_select_mlx(self):
|
||||||
|
"""Test that Whisper models automatically select MLX when MPS and mlx-whisper are available."""
|
||||||
|
# This test verifies that the models are correctly configured
|
||||||
|
# In a real Apple Silicon environment with mlx-whisper installed,
|
||||||
|
# these models would automatically use MLX
|
||||||
|
|
||||||
|
# Check that the models exist and have the correct structure
|
||||||
|
assert hasattr(WHISPER_TURBO, 'inference_framework')
|
||||||
|
assert hasattr(WHISPER_TURBO, 'repo_id')
|
||||||
|
|
||||||
|
assert hasattr(WHISPER_BASE, 'inference_framework')
|
||||||
|
assert hasattr(WHISPER_BASE, 'repo_id')
|
||||||
|
|
||||||
|
assert hasattr(WHISPER_SMALL, 'inference_framework')
|
||||||
|
assert hasattr(WHISPER_SMALL, 'repo_id')
|
||||||
|
|
||||||
|
@patch('builtins.__import__')
|
||||||
|
def test_mlx_whisper_model_initialization(self, mock_import):
|
||||||
|
"""Test MLX Whisper model initialization."""
|
||||||
|
# Mock the mlx_whisper import
|
||||||
|
mock_mlx_whisper = Mock()
|
||||||
|
mock_import.return_value = mock_mlx_whisper
|
||||||
|
|
||||||
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||||
|
asr_options = InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-tiny-mlx",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = _MlxWhisperModel(
|
||||||
|
enabled=True,
|
||||||
|
artifacts_path=None,
|
||||||
|
accelerator_options=accelerator_options,
|
||||||
|
asr_options=asr_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model.enabled is True
|
||||||
|
assert model.model_path == "mlx-community/whisper-tiny-mlx"
|
||||||
|
assert model.language == "en"
|
||||||
|
assert model.task == "transcribe"
|
||||||
|
assert model.word_timestamps is True
|
||||||
|
|
||||||
|
def test_mlx_whisper_model_import_error(self):
|
||||||
|
"""Test that ImportError is raised when mlx-whisper is not available."""
|
||||||
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||||
|
asr_options = InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-tiny-mlx",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
|
||||||
|
with pytest.raises(ImportError, match="mlx-whisper is not installed"):
|
||||||
|
_MlxWhisperModel(
|
||||||
|
enabled=True,
|
||||||
|
artifacts_path=None,
|
||||||
|
accelerator_options=accelerator_options,
|
||||||
|
asr_options=asr_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('builtins.__import__')
|
||||||
|
def test_mlx_whisper_transcribe(self, mock_import):
|
||||||
|
"""Test MLX Whisper transcription method."""
|
||||||
|
# Mock the mlx_whisper module and its transcribe function
|
||||||
|
mock_mlx_whisper = Mock()
|
||||||
|
mock_import.return_value = mock_mlx_whisper
|
||||||
|
|
||||||
|
# Mock the transcribe result
|
||||||
|
mock_result = {
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 2.5,
|
||||||
|
"text": "Hello world",
|
||||||
|
"words": [
|
||||||
|
{"start": 0.0, "end": 0.5, "word": "Hello"},
|
||||||
|
{"start": 0.5, "end": 1.0, "word": "world"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_mlx_whisper.transcribe.return_value = mock_result
|
||||||
|
|
||||||
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||||
|
asr_options = InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-tiny-mlx",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = _MlxWhisperModel(
|
||||||
|
enabled=True,
|
||||||
|
artifacts_path=None,
|
||||||
|
accelerator_options=accelerator_options,
|
||||||
|
asr_options=asr_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test transcription
|
||||||
|
audio_path = Path("test_audio.wav")
|
||||||
|
result = model.transcribe(audio_path)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].start_time == 0.0
|
||||||
|
assert result[0].end_time == 2.5
|
||||||
|
assert result[0].text == "Hello world"
|
||||||
|
assert len(result[0].words) == 2
|
||||||
|
assert result[0].words[0].text == "Hello"
|
||||||
|
assert result[0].words[1].text == "world"
|
||||||
|
|
||||||
|
# Verify mlx_whisper.transcribe was called with correct parameters
|
||||||
|
mock_mlx_whisper.transcribe.assert_called_once_with(
|
||||||
|
str(audio_path),
|
||||||
|
path_or_hf_repo="mlx-community/whisper-tiny-mlx",
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('builtins.__import__')
|
||||||
|
def test_asr_pipeline_with_mlx_whisper(self, mock_import):
|
||||||
|
"""Test that AsrPipeline can be initialized with MLX Whisper options."""
|
||||||
|
# Mock the mlx_whisper import
|
||||||
|
mock_mlx_whisper = Mock()
|
||||||
|
mock_import.return_value = mock_mlx_whisper
|
||||||
|
|
||||||
|
accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
|
||||||
|
asr_options = InlineAsrMlxWhisperOptions(
|
||||||
|
repo_id="mlx-community/whisper-tiny-mlx",
|
||||||
|
inference_framework=InferenceAsrFramework.MLX,
|
||||||
|
language="en",
|
||||||
|
task="transcribe",
|
||||||
|
word_timestamps=True,
|
||||||
|
no_speech_threshold=0.6,
|
||||||
|
logprob_threshold=-1.0,
|
||||||
|
compression_ratio_threshold=2.4,
|
||||||
|
)
|
||||||
|
pipeline_options = AsrPipelineOptions(
|
||||||
|
asr_options=asr_options,
|
||||||
|
accelerator_options=accelerator_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = AsrPipeline(pipeline_options)
|
||||||
|
assert isinstance(pipeline._model, _MlxWhisperModel)
|
||||||
|
assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
|
||||||
Reference in New Issue
Block a user