fixed linter issue
docs/examples/asr_pipeline_performance_comparison.py (vendored, new file) | 299 lines
@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
Performance comparison between CPU and MLX Whisper on Apple Silicon.

This script compares the performance of:
1. Native Whisper (forced to CPU)
2. MLX Whisper (Apple Silicon optimized)

Both use the same model size for fair comparison.
"""
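# Usage: python docs/examples/asr_pipeline_performance_comparison.py [--audio PATH]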

import argparse
import sys
import time
from pathlib import Path

# Add the repository root to the path so we can import docling
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrMlxWhisperOptions,
    InlineAsrNativeWhisperOptions,
)
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline


def create_cpu_whisper_options(model_size: str = "turbo"):
    """Create native Whisper options forced to CPU."""
    return InlineAsrNativeWhisperOptions(
        repo_id=model_size,
        inference_framework=InferenceAsrFramework.WHISPER,
        verbose=True,
        timestamps=True,
        word_timestamps=True,
        temperature=0.0,
        max_new_tokens=256,
        max_time_chunk=30.0,
    )


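# Note: the native backend above takes a bare Whisper size alias as repo_id,
# while the MLX backend expects a full Hugging Face repo id, mapped per size below.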
def create_mlx_whisper_options(model_size: str = "turbo"):
    """Create MLX Whisper options for Apple Silicon."""
    model_map = {
        "tiny": "mlx-community/whisper-tiny-mlx",
        "small": "mlx-community/whisper-small-mlx",
        "base": "mlx-community/whisper-base-mlx",
        "medium": "mlx-community/whisper-medium-mlx-8bit",
        "large": "mlx-community/whisper-large-mlx-8bit",
        "turbo": "mlx-community/whisper-turbo",
    }
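    # The medium/large entries point to 8-bit quantized conversions (per the
    # repo naming), which should reduce memory use at some cost in accuracy.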

    return InlineAsrMlxWhisperOptions(
        repo_id=model_map[model_size],
        inference_framework=InferenceAsrFramework.MLX,
        language="en",
        task="transcribe",
        word_timestamps=True,
        no_speech_threshold=0.6,
        logprob_threshold=-1.0,
        compression_ratio_threshold=2.4,
    )


def run_transcription_test(
    audio_file: Path, asr_options, device: AcceleratorDevice, test_name: str
):
    """Run a single transcription test and return timing results."""
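    # Note: durations measured here are wall-clock, so the first run for a
    # model may also include model download/initialization, not just inference.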
print(f"\n{'=' * 60}")
|
||||
print(f"Running {test_name}")
|
||||
print(f"Device: {device}")
|
||||
print(f"Model: {asr_options.repo_id}")
|
||||
print(f"Framework: {asr_options.inference_framework}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# Create pipeline options
|
||||
pipeline_options = AsrPipelineOptions(
|
||||
accelerator_options=AcceleratorOptions(device=device),
|
||||
asr_options=asr_options,
|
||||
)
|
||||
|
||||
# Create document converter
|
||||
converter = DocumentConverter(
|
||||
format_options={
|
||||
InputFormat.AUDIO: AudioFormatOption(
|
||||
pipeline_cls=AsrPipeline,
|
||||
pipeline_options=pipeline_options,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# Run transcription with timing
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = converter.convert(audio_file)
|
||||
end_time = time.time()
|
||||
|
||||
duration = end_time - start_time
|
||||
|
||||
if result.status.value == "success":
|
||||
# Extract text for verification
|
||||
text_content = []
|
||||
for item in result.document.texts:
|
||||
text_content.append(item.text)
|
||||
|
||||
print(f"✅ Success! Duration: {duration:.2f} seconds")
|
||||
print(f"Transcribed text: {''.join(text_content)[:100]}...")
|
||||
return duration, True
|
||||
else:
|
||||
print(f"❌ Failed! Status: {result.status}")
|
||||
return duration, False
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
print(f"❌ Error: {e}")
|
||||
return duration, False
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Performance comparison between CPU and MLX Whisper on Apple Silicon",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
|
||||
# Use default test audio file
|
||||
python asr_pipeline_performance_comparison.py
|
||||
|
||||
# Use your own audio file
|
||||
python asr_pipeline_performance_comparison.py --audio /path/to/your/audio.mp3
|
||||
|
||||
# Use a different audio file from the tests directory
|
||||
python asr_pipeline_performance_comparison.py --audio tests/data/audio/another_sample.wav
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--audio",
|
||||
type=str,
|
||||
help="Path to audio file for testing (default: tests/data/audio/sample_10s.mp3)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Run performance comparison between CPU and MLX Whisper."""
|
||||
args = parse_args()
|
||||
|
||||
# Check if we're on Apple Silicon
|
||||
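    # torch.backends.mps.is_built() / is_available() are PyTorch's standard
    # probes for Metal (Apple Silicon GPU) support.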
    try:
        import torch

        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    try:
        import mlx_whisper

        has_mlx_whisper = True
    except ImportError:
        has_mlx_whisper = False

    print("ASR Pipeline Performance Comparison")
    print("=" * 50)
    print(f"Apple Silicon (MPS) available: {has_mps}")
    print(f"MLX Whisper available: {has_mlx_whisper}")

    if not has_mps:
        print("❌ This test requires Apple Silicon (MPS) to be meaningful.")
        print("   MLX Whisper is optimized for Apple Silicon devices.")
        sys.exit(1)

    if not has_mlx_whisper:
        print("❌ MLX Whisper is not installed.")
        print("   Install with: pip install mlx-whisper")
        print("   Or: uv sync --extra asr")
        sys.exit(1)

    # Determine audio file path
    if args.audio:
        audio_file = Path(args.audio)
        if not audio_file.is_absolute():
            # If relative, resolve it against the repository root
            audio_file = Path(__file__).parent.parent.parent / audio_file
    else:
        # Use default test audio file
        audio_file = (
            Path(__file__).parent.parent.parent
            / "tests"
            / "data"
            / "audio"
            / "sample_10s.mp3"
        )

    if not audio_file.exists():
        print(f"❌ Audio file not found: {audio_file}")
        print("   Please check the path and try again.")
        sys.exit(1)

    print(f"Using test audio: {audio_file}")
    print(f"File size: {audio_file.stat().st_size / 1024:.1f} KB")

    # Test different model sizes
    model_sizes = ["tiny", "base", "turbo"]
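    # small/medium/large from the factories above would also work here; this
    # subset presumably keeps the total runtime short.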
    results = {}

    for model_size in model_sizes:
        print(f"\n{'#' * 80}")
        print(f"Testing model size: {model_size}")
        print(f"{'#' * 80}")

        model_results = {}

        # Test 1: Native Whisper (forced to CPU)
        cpu_options = create_cpu_whisper_options(model_size)
        cpu_duration, cpu_success = run_transcription_test(
            audio_file,
            cpu_options,
            AcceleratorDevice.CPU,
            f"Native Whisper {model_size} (CPU)",
        )
        model_results["cpu"] = {"duration": cpu_duration, "success": cpu_success}

        # Test 2: MLX Whisper (Apple Silicon optimized)
        mlx_options = create_mlx_whisper_options(model_size)
        mlx_duration, mlx_success = run_transcription_test(
            audio_file,
            mlx_options,
            AcceleratorDevice.MPS,
            f"MLX Whisper {model_size} (MPS)",
        )
        model_results["mlx"] = {"duration": mlx_duration, "success": mlx_success}

        results[model_size] = model_results

    # Print summary
    print(f"\n{'#' * 80}")
    print("PERFORMANCE COMPARISON SUMMARY")
    print(f"{'#' * 80}")
    print(
        f"{'Model':<10} {'CPU (sec)':<12} {'MLX (sec)':<12} {'Speedup':<12} {'Status':<10}"
    )
    print("-" * 80)

    for model_size, model_results in results.items():
        cpu_duration = model_results["cpu"]["duration"]
        mlx_duration = model_results["mlx"]["duration"]
        cpu_success = model_results["cpu"]["success"]
        mlx_success = model_results["mlx"]["success"]

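        # Speedup is CPU time / MLX time, so values above 1.0 mean MLX was
        # faster (e.g. 10 s on CPU vs 4 s on MLX gives 2.5x).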
        if cpu_success and mlx_success:
            speedup = cpu_duration / mlx_duration
            status = "✅ Both OK"
        elif cpu_success:
            speedup = float("inf")
            status = "❌ MLX Failed"
        elif mlx_success:
            speedup = 0
            status = "❌ CPU Failed"
        else:
            speedup = 0
            status = "❌ Both Failed"

        print(
            f"{model_size:<10} {cpu_duration:<12.2f} {mlx_duration:<12.2f} {speedup:<12.2f}x {status:<10}"
        )

    # Calculate overall improvement
    successful_tests = [
        (r["cpu"]["duration"], r["mlx"]["duration"])
        for r in results.values()
        if r["cpu"]["success"] and r["mlx"]["success"]
    ]

    if successful_tests:
        avg_cpu = sum(cpu for cpu, mlx in successful_tests) / len(successful_tests)
        avg_mlx = sum(mlx for cpu, mlx in successful_tests) / len(successful_tests)
        avg_speedup = avg_cpu / avg_mlx

        print("-" * 80)
        print(
            f"{'AVERAGE':<10} {avg_cpu:<12.2f} {avg_mlx:<12.2f} {avg_speedup:<12.2f}x {'Overall':<10}"
        )

        print(f"\n🎯 MLX Whisper provides {avg_speedup:.1f}x average speedup over CPU!")
    else:
        print("\n❌ No successful comparisons available.")


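# Example invocation, assuming the default sample file is present in the repo:
#   python docs/examples/asr_pipeline_performance_comparison.py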
if __name__ == "__main__":
    main()