#!/usr/bin/env python3
"""
Performance comparison between CPU and MLX Whisper on Apple Silicon.

This script compares the performance of:
1. Native Whisper (forced to CPU)
2. MLX Whisper (Apple Silicon optimized)

Both use the same model size for fair comparison.
"""
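
# MLX Whisper is an optional extra; when it (or Apple Silicon / MPS) is
# missing, the script falls back to a CPU-only comparison instead of failing.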
import argparse
import sys
import time
from pathlib import Path

# Add the repository root to the path so we can import docling
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrMlxWhisperOptions,
    InlineAsrNativeWhisperOptions,
)
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline

def create_cpu_whisper_options(model_size: str = "turbo"):
    """Create native Whisper options forced to CPU."""
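    # Note: for the native backend, repo_id is the plain Whisper model name
    # ("tiny", "base", "turbo", ...) rather than a Hugging Face repo id.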
    return InlineAsrNativeWhisperOptions(
        repo_id=model_size,
        inference_framework=InferenceAsrFramework.WHISPER,
        verbose=True,
        timestamps=True,
        word_timestamps=True,
        temperature=0.0,
        max_new_tokens=256,
        max_time_chunk=30.0,
    )

def create_mlx_whisper_options(model_size: str = "turbo"):
    """Create MLX Whisper options for Apple Silicon."""
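    # Hugging Face community conversions of Whisper to MLX. The medium/large
    # entries are 8-bit quantized variants, trading a bit of accuracy for a
    # smaller memory footprint.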
    model_map = {
        "tiny": "mlx-community/whisper-tiny-mlx",
        "small": "mlx-community/whisper-small-mlx",
        "base": "mlx-community/whisper-base-mlx",
        "medium": "mlx-community/whisper-medium-mlx-8bit",
        "large": "mlx-community/whisper-large-mlx-8bit",
        "turbo": "mlx-community/whisper-turbo",
    }

    return InlineAsrMlxWhisperOptions(
        repo_id=model_map[model_size],
        inference_framework=InferenceAsrFramework.MLX,
        language="en",
        task="transcribe",
        word_timestamps=True,
        no_speech_threshold=0.6,
        logprob_threshold=-1.0,
        compression_ratio_threshold=2.4,
    )

def run_transcription_test(
    audio_file: Path, asr_options, device: AcceleratorDevice, test_name: str
):
    """Run a single transcription test and return timing results."""
    print(f"\n{'=' * 60}")
    print(f"Running {test_name}")
    print(f"Device: {device}")
    print(f"Model: {asr_options.repo_id}")
    print(f"Framework: {asr_options.inference_framework}")
    print(f"{'=' * 60}")
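
    # Binding AudioFormatOption to InputFormat.AUDIO below is what routes the
    # input file through AsrPipeline with the chosen device and ASR options.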
    # Create pipeline options
    pipeline_options = AsrPipelineOptions(
        accelerator_options=AcceleratorOptions(device=device),
        asr_options=asr_options,
    )

    # Create document converter
    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )

    # Run transcription with timing
    start_time = time.time()
    try:
        result = converter.convert(audio_file)
        end_time = time.time()

        duration = end_time - start_time

        if result.status.value == "success":
            # Extract text for verification
            text_content = []
            for item in result.document.texts:
                text_content.append(item.text)

            print(f"✅ Success! Duration: {duration:.2f} seconds")
            print(f"Transcribed text: {''.join(text_content)[:100]}...")
            return duration, True
        else:
            print(f"❌ Failed! Status: {result.status}")
            return duration, False

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time
        print(f"❌ Error: {e}")
        return duration, False

def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Performance comparison between CPU and MLX Whisper on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:

  # Use default test audio file
  python asr_pipeline_performance_comparison.py

  # Use your own audio file
  python asr_pipeline_performance_comparison.py --audio /path/to/your/audio.mp3

  # Use a different audio file from the tests directory
  python asr_pipeline_performance_comparison.py --audio tests/data/audio/another_sample.wav
""",
    )

    parser.add_argument(
        "--audio",
        type=str,
        help="Path to audio file for testing (default: tests/data/audio/sample_10s.mp3)",
    )

    return parser.parse_args()

def main():
    """Run performance comparison between CPU and MLX Whisper."""
    args = parse_args()
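
    # MPS (Metal Performance Shaders) support doubles as the Apple Silicon
    # check; torch is imported inside a try block so the script still runs
    # without it.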
    # Check if we're on Apple Silicon
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    try:
        import mlx_whisper

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    print("ASR Pipeline Performance Comparison")
    print("=" * 50)
    print(f"Apple Silicon (MPS) available: {has_mps}")
    print(f"MLX Whisper available: {has_mlx_whisper}")
    if not has_mps:
        print("⚠️ Apple Silicon (MPS) not available - running CPU-only comparison")
        print("   MLX Whisper is optimized for Apple Silicon; run there to see its performance benefits.")

    if not has_mlx_whisper:
        print("⚠️ MLX Whisper not installed - running CPU-only comparison")
        print("   Install with: pip install mlx-whisper")
        print("   Or: uv sync --extra asr")
    # Determine audio file path
    if args.audio:
        audio_file = Path(args.audio)
        if not audio_file.is_absolute():
            # If relative, resolve against the repository root
            audio_file = Path(__file__).parent.parent.parent / audio_file
    else:
        # Use default test audio file
        audio_file = (
            Path(__file__).parent.parent.parent
            / "tests"
            / "data"
            / "audio"
            / "sample_10s.mp3"
        )

    if not audio_file.exists():
        print(f"❌ Audio file not found: {audio_file}")
        print("   Please check the path and try again.")
        sys.exit(1)

    print(f"Using test audio: {audio_file}")
    print(f"File size: {audio_file.stat().st_size / 1024:.1f} KB")
    # Test different model sizes
    model_sizes = ["tiny", "base", "turbo"]
    results = {}
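    # Three sizes keep the runtime short; create_mlx_whisper_options also maps
    # "small", "medium", and "large" if you want a fuller sweep.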
    for model_size in model_sizes:
        print(f"\n{'#' * 80}")
        print(f"Testing model size: {model_size}")
        print(f"{'#' * 80}")

        model_results = {}

        # Test 1: Native Whisper (forced to CPU)
        cpu_options = create_cpu_whisper_options(model_size)
        cpu_duration, cpu_success = run_transcription_test(
            audio_file,
            cpu_options,
            AcceleratorDevice.CPU,
            f"Native Whisper {model_size} (CPU)",
        )
        model_results["cpu"] = {"duration": cpu_duration, "success": cpu_success}

        # Test 2: MLX Whisper (Apple Silicon optimized) - only if available
        if has_mps and has_mlx_whisper:
            mlx_options = create_mlx_whisper_options(model_size)
            mlx_duration, mlx_success = run_transcription_test(
                audio_file,
                mlx_options,
                AcceleratorDevice.MPS,
                f"MLX Whisper {model_size} (MPS)",
            )
            model_results["mlx"] = {"duration": mlx_duration, "success": mlx_success}
        else:
            print(f"\n{'=' * 60}")
            print(f"Skipping MLX Whisper {model_size} (MPS) - not available")
            print(f"{'=' * 60}")
            model_results["mlx"] = {"duration": 0.0, "success": False}

        results[model_size] = model_results
    # Print summary
    print(f"\n{'#' * 80}")
    print("PERFORMANCE COMPARISON SUMMARY")
    print(f"{'#' * 80}")
    print(
        f"{'Model':<10} {'CPU (sec)':<12} {'MLX (sec)':<12} {'Speedup':<12} {'Status':<10}"
    )
    print("-" * 80)
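    # Speedup is CPU wall time over MLX wall time; inf flags an MLX failure
    # (or skip) and 0 flags a CPU failure, so both stand out in the table.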
    for model_size, model_results in results.items():
        cpu_duration = model_results["cpu"]["duration"]
        mlx_duration = model_results["mlx"]["duration"]
        cpu_success = model_results["cpu"]["success"]
        mlx_success = model_results["mlx"]["success"]

        if cpu_success and mlx_success:
            speedup = cpu_duration / mlx_duration
            status = "✅ Both OK"
        elif cpu_success:
            speedup = float("inf")
            status = "❌ MLX Failed"
        elif mlx_success:
            speedup = 0
            status = "❌ CPU Failed"
        else:
            speedup = 0
            status = "❌ Both Failed"

        print(
            f"{model_size:<10} {cpu_duration:<12.2f} {mlx_duration:<12.2f} {speedup:<12.2f}x {status:<10}"
        )

    # Calculate overall improvement
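    # The average below is a ratio of mean durations (equivalently, total CPU
    # time over total MLX time), not a mean of per-model speedups, so longer
    # runs weigh more heavily.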
    successful_tests = [
        (r["cpu"]["duration"], r["mlx"]["duration"])
        for r in results.values()
        if r["cpu"]["success"] and r["mlx"]["success"]
    ]

    if successful_tests:
        avg_cpu = sum(cpu for cpu, mlx in successful_tests) / len(successful_tests)
        avg_mlx = sum(mlx for cpu, mlx in successful_tests) / len(successful_tests)
        avg_speedup = avg_cpu / avg_mlx

        print("-" * 80)
        print(
            f"{'AVERAGE':<10} {avg_cpu:<12.2f} {avg_mlx:<12.2f} {avg_speedup:<12.2f}x {'Overall':<10}"
        )

        print(f"\n🎯 MLX Whisper provides {avg_speedup:.1f}x average speedup over CPU!")
    else:
        if has_mps and has_mlx_whisper:
            print("\n❌ No successful comparisons available.")
        else:
            print("\n⚠️ MLX Whisper not available - only CPU results shown.")
            print(
                "   Install MLX Whisper and run on Apple Silicon for performance comparison."
            )

if __name__ == "__main__":
    main()