From 96c669d155c8e9bd6455ecff4720933ad7d9e7cb Mon Sep 17 00:00:00 2001
From: Ken Steele
Date: Thu, 2 Oct 2025 05:51:51 -0700
Subject: [PATCH] fixed linter issue

---
 .../asr_pipeline_performance_comparison.py    | 299 ++++++++++++++++++
 1 file changed, 299 insertions(+)
 create mode 100644 docs/examples/asr_pipeline_performance_comparison.py

diff --git a/docs/examples/asr_pipeline_performance_comparison.py b/docs/examples/asr_pipeline_performance_comparison.py
new file mode 100644
index 00000000..1d078408
--- /dev/null
+++ b/docs/examples/asr_pipeline_performance_comparison.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+"""
+Performance comparison between CPU and MLX Whisper on Apple Silicon.
+
+This script compares the performance of:
+1. Native Whisper (forced to CPU)
+2. MLX Whisper (Apple Silicon optimized)
+
+Both use the same model size for fair comparison.
+"""
+
+import argparse
+import sys
+import time
+from pathlib import Path
+
+# Add the repository root to the path so we can import docling
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
+from docling.datamodel.base_models import InputFormat
+from docling.datamodel.pipeline_options import AsrPipelineOptions
+from docling.datamodel.pipeline_options_asr_model import (
+    InferenceAsrFramework,
+    InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
+)
+from docling.document_converter import AudioFormatOption, DocumentConverter
+from docling.pipeline.asr_pipeline import AsrPipeline
+
+
+def create_cpu_whisper_options(model_size: str = "turbo"):
+    """Create native Whisper options forced to CPU."""
+    return InlineAsrNativeWhisperOptions(
+        repo_id=model_size,
+        inference_framework=InferenceAsrFramework.WHISPER,
+        verbose=True,
+        timestamps=True,
+        word_timestamps=True,
+        temperature=0.0,
+        max_new_tokens=256,
+        max_time_chunk=30.0,
+    )
+
+
+def create_mlx_whisper_options(model_size: str = "turbo"):
+    """Create MLX Whisper options for Apple Silicon."""
+    model_map = {
+        "tiny": "mlx-community/whisper-tiny-mlx",
+        "small": "mlx-community/whisper-small-mlx",
+        "base": "mlx-community/whisper-base-mlx",
+        "medium": "mlx-community/whisper-medium-mlx-8bit",
+        "large": "mlx-community/whisper-large-mlx-8bit",
+        "turbo": "mlx-community/whisper-turbo",
+    }
+
+    return InlineAsrMlxWhisperOptions(
+        repo_id=model_map[model_size],
+        inference_framework=InferenceAsrFramework.MLX,
+        language="en",
+        task="transcribe",
+        word_timestamps=True,
+        no_speech_threshold=0.6,
+        logprob_threshold=-1.0,
+        compression_ratio_threshold=2.4,
+    )
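+
+
+# Illustrative sketch only: the two builders above return inline option
+# objects that run_transcription_test() below feeds into AsrPipelineOptions.
+# Using nothing beyond what this script already imports, a standalone
+# conversion would look roughly like:
+#
+#   opts = AsrPipelineOptions(asr_options=create_mlx_whisper_options("tiny"))
+#   converter = DocumentConverter(
+#       format_options={
+#           InputFormat.AUDIO: AudioFormatOption(
+#               pipeline_cls=AsrPipeline, pipeline_options=opts
+#           )
+#       }
+#   )
+#   result = converter.convert("tests/data/audio/sample_10s.mp3")
+#   texts = [t.text for t in result.document.texts]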
+
+
+def run_transcription_test(
+    audio_file: Path, asr_options, device: AcceleratorDevice, test_name: str
+):
+    """Run a single transcription test and return timing results."""
+    print(f"\n{'=' * 60}")
+    print(f"Running {test_name}")
+    print(f"Device: {device}")
+    print(f"Model: {asr_options.repo_id}")
+    print(f"Framework: {asr_options.inference_framework}")
+    print(f"{'=' * 60}")
+
+    # Create pipeline options
+    pipeline_options = AsrPipelineOptions(
+        accelerator_options=AcceleratorOptions(device=device),
+        asr_options=asr_options,
+    )
+
+    # Create document converter
+    converter = DocumentConverter(
+        format_options={
+            InputFormat.AUDIO: AudioFormatOption(
+                pipeline_cls=AsrPipeline,
+                pipeline_options=pipeline_options,
+            )
+        }
+    )
+
+    # Run transcription with timing
+    start_time = time.time()
+    try:
+        result = converter.convert(audio_file)
+        end_time = time.time()
+
+        duration = end_time - start_time
+
+        if result.status.value == "success":
+            # Extract text for verification
+            text_content = []
+            for item in result.document.texts:
+                text_content.append(item.text)
+
+            print(f"✅ Success! Duration: {duration:.2f} seconds")
+            print(f"Transcribed text: {''.join(text_content)[:100]}...")
+            return duration, True
+        else:
+            print(f"❌ Failed! Status: {result.status}")
+            return duration, False
+
+    except Exception as e:
+        end_time = time.time()
+        duration = end_time - start_time
+        print(f"❌ Error: {e}")
+        return duration, False
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Performance comparison between CPU and MLX Whisper on Apple Silicon",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+
+# Use default test audio file
+python asr_pipeline_performance_comparison.py
+
+# Use your own audio file
+python asr_pipeline_performance_comparison.py --audio /path/to/your/audio.mp3
+
+# Use a different audio file from the tests directory
+python asr_pipeline_performance_comparison.py --audio tests/data/audio/another_sample.wav
+        """,
+    )
+
+    parser.add_argument(
+        "--audio",
+        type=str,
+        help="Path to audio file for testing (default: tests/data/audio/sample_10s.mp3)",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Run performance comparison between CPU and MLX Whisper."""
+    args = parse_args()
+
+    # Check if we're on Apple Silicon
+    try:
+        import torch
+
+        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
+    except ImportError:
+        has_mps = False
+
+    try:
+        import mlx_whisper  # noqa: F401
+
+        has_mlx_whisper = True
+    except ImportError:
+        has_mlx_whisper = False
+
+    print("ASR Pipeline Performance Comparison")
+    print("=" * 50)
+    print(f"Apple Silicon (MPS) available: {has_mps}")
+    print(f"MLX Whisper available: {has_mlx_whisper}")
+
+    if not has_mps:
+        print("❌ This test requires Apple Silicon (MPS) to be meaningful.")
+        print("   MLX Whisper is optimized for Apple Silicon devices.")
+        sys.exit(1)
+
+    if not has_mlx_whisper:
+        print("❌ MLX Whisper is not installed.")
+        print("   Install with: pip install mlx-whisper")
+        print("   Or: uv sync --extra asr")
+        sys.exit(1)
+
+    # Determine audio file path
+    if args.audio:
+        audio_file = Path(args.audio)
+        if not audio_file.is_absolute():
+            # If relative path, resolve it against the repository root
+            audio_file = Path(__file__).parent.parent.parent / audio_file
+    else:
+        # Use default test audio file
+        audio_file = (
+            Path(__file__).parent.parent.parent
+            / "tests"
+            / "data"
+            / "audio"
+            / "sample_10s.mp3"
+        )
+
+    if not audio_file.exists():
+        print(f"❌ Audio file not found: {audio_file}")
+        print("   Please check the path and try again.")
+        sys.exit(1)
+
+    print(f"Using test audio: {audio_file}")
+    print(f"File size: {audio_file.stat().st_size / 1024:.1f} KB")
+
+    # Test different model sizes
+    model_sizes = ["tiny", "base", "turbo"]
+    results = {}
+
+    for model_size in model_sizes:
+        print(f"\n{'#' * 80}")
+        print(f"Testing model size: {model_size}")
+        print(f"{'#' * 80}")
+
+        model_results = {}
+
+        # Test 1: Native Whisper (forced to CPU)
+        cpu_options = create_cpu_whisper_options(model_size)
+        cpu_duration, cpu_success = run_transcription_test(
+            audio_file,
+            cpu_options,
+            AcceleratorDevice.CPU,
+            f"Native Whisper {model_size} (CPU)",
+        )
+        model_results["cpu"] = {"duration": cpu_duration, "success": cpu_success}
+
+        # Test 2: MLX Whisper (Apple Silicon optimized)
+        mlx_options = create_mlx_whisper_options(model_size)
+        mlx_duration, mlx_success = run_transcription_test(
+            audio_file,
+            mlx_options,
+            AcceleratorDevice.MPS,
+            f"MLX Whisper {model_size} (MPS)",
+        )
+        model_results["mlx"] = {"duration": mlx_duration, "success": mlx_success}
+
+        results[model_size] = model_results
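+
+    # How to read the summary table: "Speedup" is cpu_duration / mlx_duration,
+    # so 2.00x means MLX transcribed the same clip twice as fast as the CPU
+    # run; inf and 0 are display conventions for runs where one side failed.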
+
+    # Print summary
+    print(f"\n{'#' * 80}")
+    print("PERFORMANCE COMPARISON SUMMARY")
+    print(f"{'#' * 80}")
+    print(
+        f"{'Model':<10} {'CPU (sec)':<12} {'MLX (sec)':<12} {'Speedup':<12} {'Status':<10}"
+    )
+    print("-" * 80)
+
+    for model_size, model_results in results.items():
+        cpu_duration = model_results["cpu"]["duration"]
+        mlx_duration = model_results["mlx"]["duration"]
+        cpu_success = model_results["cpu"]["success"]
+        mlx_success = model_results["mlx"]["success"]
+
+        if cpu_success and mlx_success:
+            speedup = cpu_duration / mlx_duration
+            status = "✅ Both OK"
+        elif cpu_success:
+            speedup = float("inf")
+            status = "❌ MLX Failed"
+        elif mlx_success:
+            speedup = 0
+            status = "❌ CPU Failed"
+        else:
+            speedup = 0
+            status = "❌ Both Failed"
+
+        print(
+            f"{model_size:<10} {cpu_duration:<12.2f} {mlx_duration:<12.2f} {speedup:<12.2f}x {status:<10}"
+        )
+
+    # Calculate overall improvement
+    successful_tests = [
+        (r["cpu"]["duration"], r["mlx"]["duration"])
+        for r in results.values()
+        if r["cpu"]["success"] and r["mlx"]["success"]
+    ]
+
+    if successful_tests:
+        avg_cpu = sum(cpu for cpu, mlx in successful_tests) / len(successful_tests)
+        avg_mlx = sum(mlx for cpu, mlx in successful_tests) / len(successful_tests)
+        avg_speedup = avg_cpu / avg_mlx
+
+        print("-" * 80)
+        print(
+            f"{'AVERAGE':<10} {avg_cpu:<12.2f} {avg_mlx:<12.2f} {avg_speedup:<12.2f}x {'Overall':<10}"
+        )
+
+        print(f"\n🎯 MLX Whisper provides {avg_speedup:.1f}x average speedup over CPU!")
+    else:
+        print("\n❌ No successful comparisons available.")
+
+
+if __name__ == "__main__":
+    main()