docling/docs/examples/asr_pipeline_performance_comparison.py
#!/usr/bin/env python3
"""
Performance comparison between CPU and MLX Whisper on Apple Silicon.
This script compares the performance of:
1. Native Whisper (forced to CPU)
2. MLX Whisper (Apple Silicon optimized)
Both use the same model size for fair comparison.
"""
import argparse
import sys
import time
from pathlib import Path

# Add the repository root to the path so we can import docling
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrMlxWhisperOptions,
    InlineAsrNativeWhisperOptions,
)
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline


def create_cpu_whisper_options(model_size: str = "turbo"):
    """Create native Whisper options forced to CPU."""
    return InlineAsrNativeWhisperOptions(
        repo_id=model_size,
        inference_framework=InferenceAsrFramework.WHISPER,
        verbose=True,
        timestamps=True,
        word_timestamps=True,
        temperature=0.0,
        max_new_tokens=256,
        max_time_chunk=30.0,
    )
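
# Note: for the native backend the repo_id above is a plain Whisper model
# alias ("tiny", "base", "turbo", ...) rather than a Hugging Face repo id;
# docling is expected to pass it through to openai-whisper's model loader,
# so any alias accepted by `whisper.load_model` should work here.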


def create_mlx_whisper_options(model_size: str = "turbo"):
    """Create MLX Whisper options for Apple Silicon."""
    model_map = {
        "tiny": "mlx-community/whisper-tiny-mlx",
        "small": "mlx-community/whisper-small-mlx",
        "base": "mlx-community/whisper-base-mlx",
        "medium": "mlx-community/whisper-medium-mlx-8bit",
        "large": "mlx-community/whisper-large-mlx-8bit",
        "turbo": "mlx-community/whisper-turbo",
    }
    return InlineAsrMlxWhisperOptions(
        repo_id=model_map[model_size],
        inference_framework=InferenceAsrFramework.MLX,
        language="en",
        task="transcribe",
        word_timestamps=True,
        no_speech_threshold=0.6,
        logprob_threshold=-1.0,
        compression_ratio_threshold=2.4,
    )
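
# Illustrative sanity check (a sketch, assuming the mlx-community
# checkpoints above are available on the Hugging Face Hub):
#
#     opts = create_mlx_whisper_options("tiny")
#     assert opts.repo_id == "mlx-community/whisper-tiny-mlx"
#     assert opts.inference_framework == InferenceAsrFramework.MLX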


def run_transcription_test(
    audio_file: Path, asr_options, device: AcceleratorDevice, test_name: str
):
    """Run a single transcription test and return timing results."""
    print(f"\n{'=' * 60}")
    print(f"Running {test_name}")
    print(f"Device: {device}")
    print(f"Model: {asr_options.repo_id}")
    print(f"Framework: {asr_options.inference_framework}")
    print(f"{'=' * 60}")

    # Create pipeline options
    pipeline_options = AsrPipelineOptions(
        accelerator_options=AcceleratorOptions(device=device),
        asr_options=asr_options,
    )

    # Create document converter
    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )

    # Run transcription with timing
    start_time = time.time()
    try:
        result = converter.convert(audio_file)
        end_time = time.time()
        duration = end_time - start_time

        if result.status.value == "success":
            # Extract text for verification; join with spaces so that
            # separate text items don't run together
            text_content = [item.text for item in result.document.texts]
            print(f"✅ Success! Duration: {duration:.2f} seconds")
            print(f"Transcribed text: {' '.join(text_content)[:100]}...")
            return duration, True
        else:
            print(f"❌ Failed! Status: {result.status}")
            return duration, False
    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time
        print(f"❌ Error: {e}")
        return duration, False
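
# Minimal standalone usage sketch (illustrative only; assumes the sample
# audio shipped with the docling repository is present):
#
#     audio = Path(__file__).parent.parent.parent / "tests/data/audio/sample_10s.mp3"
#     opts = create_cpu_whisper_options("tiny")
#     duration, ok = run_transcription_test(
#         audio, opts, AcceleratorDevice.CPU, "smoke test"
#     )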


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Performance comparison between CPU and MLX Whisper on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Use the default test audio file
  python asr_pipeline_performance_comparison.py

  # Use your own audio file
  python asr_pipeline_performance_comparison.py --audio /path/to/your/audio.mp3

  # Use a different audio file from the tests directory
  python asr_pipeline_performance_comparison.py --audio tests/data/audio/another_sample.wav
""",
    )
    parser.add_argument(
        "--audio",
        type=str,
        help="Path to audio file for testing (default: tests/data/audio/sample_10s.mp3)",
    )
    return parser.parse_args()


def main():
    """Run performance comparison between CPU and MLX Whisper."""
    args = parse_args()

    # Check if we're on Apple Silicon
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    try:
        import mlx_whisper

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    print("ASR Pipeline Performance Comparison")
    print("=" * 50)
    print(f"Apple Silicon (MPS) available: {has_mps}")
    print(f"MLX Whisper available: {has_mlx_whisper}")

    if not has_mps:
        print("❌ This test requires Apple Silicon (MPS) to be meaningful.")
        print("   MLX Whisper is optimized for Apple Silicon devices.")
        sys.exit(1)

    if not has_mlx_whisper:
        print("❌ MLX Whisper is not installed.")
        print("   Install with: pip install mlx-whisper")
        print("   Or: uv sync --extra asr")
        sys.exit(1)
    # Determine the audio file path
    if args.audio:
        audio_file = Path(args.audio)
        if not audio_file.is_absolute():
            # Resolve relative paths against the repository root
            audio_file = Path(__file__).parent.parent.parent / audio_file
    else:
        # Use the default test audio file
        audio_file = (
            Path(__file__).parent.parent.parent
            / "tests"
            / "data"
            / "audio"
            / "sample_10s.mp3"
        )

    if not audio_file.exists():
        print(f"❌ Audio file not found: {audio_file}")
        print("   Please check the path and try again.")
        sys.exit(1)

    print(f"Using test audio: {audio_file}")
    print(f"File size: {audio_file.stat().st_size / 1024:.1f} KB")

    # Test different model sizes
    model_sizes = ["tiny", "base", "turbo"]
    results = {}
    for model_size in model_sizes:
        print(f"\n{'#' * 80}")
        print(f"Testing model size: {model_size}")
        print(f"{'#' * 80}")

        model_results = {}

        # Test 1: Native Whisper (forced to CPU)
        cpu_options = create_cpu_whisper_options(model_size)
        cpu_duration, cpu_success = run_transcription_test(
            audio_file,
            cpu_options,
            AcceleratorDevice.CPU,
            f"Native Whisper {model_size} (CPU)",
        )
        model_results["cpu"] = {"duration": cpu_duration, "success": cpu_success}

        # Test 2: MLX Whisper (Apple Silicon optimized)
        mlx_options = create_mlx_whisper_options(model_size)
        mlx_duration, mlx_success = run_transcription_test(
            audio_file,
            mlx_options,
            AcceleratorDevice.MPS,
            f"MLX Whisper {model_size} (MPS)",
        )
        model_results["mlx"] = {"duration": mlx_duration, "success": mlx_success}

        results[model_size] = model_results
    # Print summary
    print(f"\n{'#' * 80}")
    print("PERFORMANCE COMPARISON SUMMARY")
    print(f"{'#' * 80}")
    print(
        f"{'Model':<10} {'CPU (sec)':<12} {'MLX (sec)':<12} {'Speedup':<12} {'Status':<10}"
    )
    print("-" * 80)

    for model_size, model_results in results.items():
        cpu_duration = model_results["cpu"]["duration"]
        mlx_duration = model_results["mlx"]["duration"]
        cpu_success = model_results["cpu"]["success"]
        mlx_success = model_results["mlx"]["success"]

        if cpu_success and mlx_success:
            speedup = cpu_duration / mlx_duration
            status = "✅ Both OK"
        elif cpu_success:
            speedup = float("inf")
            status = "❌ MLX Failed"
        elif mlx_success:
            speedup = 0
            status = "❌ CPU Failed"
        else:
            speedup = 0
            status = "❌ Both Failed"

        print(
            f"{model_size:<10} {cpu_duration:<12.2f} {mlx_duration:<12.2f} {speedup:<12.2f}x {status:<10}"
        )

    # Calculate overall improvement across runs where both backends succeeded
    successful_tests = [
        (r["cpu"]["duration"], r["mlx"]["duration"])
        for r in results.values()
        if r["cpu"]["success"] and r["mlx"]["success"]
    ]
    if successful_tests:
        avg_cpu = sum(cpu for cpu, mlx in successful_tests) / len(successful_tests)
        avg_mlx = sum(mlx for cpu, mlx in successful_tests) / len(successful_tests)
        avg_speedup = avg_cpu / avg_mlx

        print("-" * 80)
        print(
            f"{'AVERAGE':<10} {avg_cpu:<12.2f} {avg_mlx:<12.2f} {avg_speedup:<12.2f}x {'Overall':<10}"
        )
        print(f"\n🎯 MLX Whisper provides {avg_speedup:.1f}x average speedup over CPU!")
    else:
        print("\n❌ No successful comparisons available.")


if __name__ == "__main__":
    main()