Mirror of https://github.com/DS4SD/docling.git
Fix pre-commit checks and add proper type safety
In the CLI convert() function:

@@ -615,7 +615,17 @@ def convert( # noqa: C901
     # Auto-detect pipeline based on input file formats
     if pipeline == ProcessingPipeline.STANDARD:
         # Check if any input files are audio files by extension
-        audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
+        audio_extensions = {
+            ".mp3",
+            ".wav",
+            ".m4a",
+            ".aac",
+            ".ogg",
+            ".flac",
+            ".mp4",
+            ".avi",
+            ".mov",
+        }
         for path in input_doc_paths:
             if path.suffix.lower() in audio_extensions:
                 pipeline = ProcessingPipeline.ASR
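The auto-detection above keys off nothing but the file suffix. A minimal standalone sketch of the same check, with ProcessingPipeline simplified to plain strings since only the suffix logic is being illustrated:

    # Standalone sketch of the suffix-based auto-detection; "standard"/"asr"
    # stand in for the ProcessingPipeline enum values used by the CLI.
    from pathlib import Path

    AUDIO_EXTENSIONS = {
        ".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac", ".mp4", ".avi", ".mov",
    }

    def detect_pipeline(input_doc_paths: list[Path]) -> str:
        pipeline = "standard"
        for path in input_doc_paths:
            # Path.suffix keeps the leading dot; lowercasing makes ".MP3" match too.
            if path.suffix.lower() in AUDIO_EXTENSIONS:
                pipeline = "asr"
        return pipeline

    print(detect_pipeline([Path("talk.MP3")]))    # -> asr
    print(detect_pipeline([Path("report.pdf")]))  # -> standard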
In the Whisper model specs module (docling.datamodel.asr_model_specs):

@@ -10,13 +10,14 @@ from docling.datamodel.pipeline_options_asr_model import (
     # AsrResponseFormat,
     # ApiAsrOptions,
     InferenceAsrFramework,
-    InlineAsrNativeWhisperOptions,
     InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
     TransformersModelType,
 )
 
 _log = logging.getLogger(__name__)
 
+
 def _get_whisper_tiny_model():
     """
     Get the best Whisper Tiny model for the current hardware.
@@ -27,6 +28,7 @@ def _get_whisper_tiny_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -34,6 +36,7 @@ def _get_whisper_tiny_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
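The same two probes recur below for every model size: import torch and test MPS, then import mlx_whisper, falling back to False when either is missing. A condensed sketch of the pattern (the helper name _detect_mlx_support is illustrative, not part of the source):

    def _detect_mlx_support() -> bool:
        """Illustrative probe: True only if an MPS device and mlx-whisper are both usable."""
        try:
            import torch

            has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
        except ImportError:
            has_mps = False

        try:
            import mlx_whisper  # noqa: F401

            has_mlx_whisper = True
        except ImportError:
            has_mlx_whisper = False

        return has_mps and has_mlx_whisper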
@@ -66,6 +69,7 @@ def _get_whisper_tiny_model():
 # Create the model instance
 WHISPER_TINY = _get_whisper_tiny_model()
 
+
 def _get_whisper_small_model():
     """
     Get the best Whisper Small model for the current hardware.
@@ -76,6 +80,7 @@ def _get_whisper_small_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -83,6 +88,7 @@ def _get_whisper_small_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -115,6 +121,7 @@ def _get_whisper_small_model():
 # Create the model instance
 WHISPER_SMALL = _get_whisper_small_model()
 
+
 def _get_whisper_medium_model():
     """
     Get the best Whisper Medium model for the current hardware.
@@ -125,6 +132,7 @@ def _get_whisper_medium_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -132,6 +140,7 @@ def _get_whisper_medium_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -164,6 +173,7 @@ def _get_whisper_medium_model():
 # Create the model instance
 WHISPER_MEDIUM = _get_whisper_medium_model()
 
+
 def _get_whisper_base_model():
     """
     Get the best Whisper Base model for the current hardware.
@@ -174,6 +184,7 @@ def _get_whisper_base_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -181,6 +192,7 @@ def _get_whisper_base_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -213,6 +225,7 @@ def _get_whisper_base_model():
 # Create the model instance
 WHISPER_BASE = _get_whisper_base_model()
 
+
 def _get_whisper_large_model():
     """
     Get the best Whisper Large model for the current hardware.
@@ -223,6 +236,7 @@ def _get_whisper_large_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -230,6 +244,7 @@ def _get_whisper_large_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
@@ -262,6 +277,7 @@ def _get_whisper_large_model():
 # Create the model instance
 WHISPER_LARGE = _get_whisper_large_model()
 
+
 def _get_whisper_turbo_model():
     """
     Get the best Whisper Turbo model for the current hardware.
@@ -272,6 +288,7 @@ def _get_whisper_turbo_model():
     # Check if MPS is available (Apple Silicon)
     try:
         import torch
+
         has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
     except ImportError:
         has_mps = False
@@ -279,6 +296,7 @@ def _get_whisper_turbo_model():
     # Check if mlx-whisper is available
     try:
         import mlx_whisper  # type: ignore
+
         has_mlx_whisper = True
     except ImportError:
         has_mlx_whisper = False
In docling.datamodel.pipeline_options_asr_model:

@@ -63,6 +63,7 @@ class InlineAsrMlxWhisperOptions(InlineAsrOptions):
 
     Uses mlx-whisper library for efficient inference on Apple Silicon devices.
     """
 
     inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
+
     language: str = "en"
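A minimal construction sketch for the options class above, assuming the pydantic-style fields visible in this diff; the repo_id value is a hypothetical example, not taken from the source:

    from docling.datamodel.pipeline_options_asr_model import (
        InferenceAsrFramework,
        InlineAsrMlxWhisperOptions,
    )

    # repo_id here is a hypothetical example value.
    opts = InlineAsrMlxWhisperOptions(repo_id="mlx-community/whisper-tiny")
    assert opts.inference_framework == InferenceAsrFramework.MLX  # MLX by default
    assert opts.language == "en"  # default per the diff above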
In docling.pipeline.asr_pipeline:

@@ -4,7 +4,7 @@ import re
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union, cast
 
 from docling_core.types.doc import DoclingDocument, DocumentOrigin
 
@@ -32,8 +32,8 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
-    InlineAsrNativeWhisperOptions,
     InlineAsrMlxWhisperOptions,
+    InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
 )
@@ -334,7 +334,7 @@ class _MlxWhisperModel:
                 start_time=segment.get("start"),
                 end_time=segment.get("end"),
                 text=segment.get("text", "").strip(),
-                words=[]
+                words=[],
             )
 
             # Add word-level timestamps if available
@@ -359,26 +359,27 @@ class AsrPipeline(BasePipeline):
         self.keep_backend = True
 
         self.pipeline_options: AsrPipelineOptions = pipeline_options
+        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]
 
         if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
-            asr_options: InlineAsrNativeWhisperOptions = (
+            native_asr_options: InlineAsrNativeWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=native_asr_options,
             )
         elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
-            asr_options: InlineAsrMlxWhisperOptions = (
+            mlx_asr_options: InlineAsrMlxWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _MlxWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=mlx_asr_options,
             )
         else:
             _log.error(f"No model support for {self.pipeline_options.asr_options}")
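The changes above are the "proper type safety" half of the commit message. mypy infers an attribute's type from its first assignment, so assigning _NativeWhisperModel in one branch and _MlxWhisperModel in another fails unless the attribute is declared as a Union up front; likewise, one local asr_options name cannot carry two different annotations in the same scope, hence the rename to native_asr_options and mlx_asr_options. A self-contained sketch of the pattern (class names are placeholders):

    from typing import Union

    class _NativeModel: ...

    class _MlxModel: ...

    class Pipeline:
        def __init__(self, use_mlx: bool) -> None:
            # Declare the full Union type before any branch assigns the attribute;
            # otherwise mypy pins the type to whichever class is assigned first.
            self._model: Union[_NativeModel, _MlxModel]
            if use_mlx:
                self._model = _MlxModel()
            else:
                self._model = _NativeModel()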
docs/examples/mlx_whisper_example.py (vendored, 22 lines changed)
@@ -12,19 +12,19 @@ from pathlib import Path
 # Add the repository root to the path so we can import docling
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
     WHISPER_BASE,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
 )
-from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_options import AsrPipelineOptions
+from docling.document_converter import AudioFormatOption, DocumentConverter
 from docling.pipeline.asr_pipeline import AsrPipeline
-from docling.document_converter import DocumentConverter, AudioFormatOption
 
 
 def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
@@ -51,7 +51,9 @@ def transcribe_audio_with_mlx_whisper(audio_file_path: str, model_size: str = "base"):
     }
 
     if model_size not in model_map:
-        raise ValueError(f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}")
+        raise ValueError(
+            f"Invalid model size: {model_size}. Choose from: {list(model_map.keys())}"
+        )
 
     asr_options = model_map[model_size]
 
@@ -105,10 +107,14 @@ def main():
 
     try:
         print(f"Transcribing '{audio_file_path}' using Whisper {model_size} model...")
-        print("Note: MLX optimization is automatically used on Apple Silicon when available.")
+        print(
+            "Note: MLX optimization is automatically used on Apple Silicon when available."
+        )
         print()
 
-        transcribed_text = transcribe_audio_with_mlx_whisper(audio_file_path, model_size)
+        transcribed_text = transcribe_audio_with_mlx_whisper(
+            audio_file_path, model_size
+        )
 
         print("Transcription Result:")
         print("=" * 50)
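Putting the pieces together, a conversion through the ASR pipeline is wired roughly as follows. This sketch only reuses names imported in the example file above; audio.mp3 is a placeholder path:

    from pathlib import Path

    from docling.datamodel.asr_model_specs import WHISPER_BASE
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import AsrPipelineOptions
    from docling.document_converter import AudioFormatOption, DocumentConverter
    from docling.pipeline.asr_pipeline import AsrPipeline

    pipeline_options = AsrPipelineOptions()
    pipeline_options.asr_options = WHISPER_BASE  # resolves to MLX on Apple Silicon when available

    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )
    result = converter.convert(Path("audio.mp3"))  # placeholder input file
    print(result.document.export_to_markdown())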
In the MLX Whisper integration tests:

@@ -1,25 +1,27 @@
 """
 Test MLX Whisper integration for Apple Silicon ASR pipeline.
 """
-import pytest
+
 from pathlib import Path
 from unittest.mock import Mock, patch
 
+import pytest
+
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
     WHISPER_BASE,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
 )
+from docling.datamodel.pipeline_options import AsrPipelineOptions
 from docling.datamodel.pipeline_options_asr_model import (
     InferenceAsrFramework,
     InlineAsrMlxWhisperOptions,
 )
-from docling.datamodel.accelerator_options import AcceleratorOptions, AcceleratorDevice
 from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel
-from docling.datamodel.pipeline_options import AsrPipelineOptions
 
 
 class TestMlxWhisperIntegration:
@@ -47,16 +49,16 @@ class TestMlxWhisperIntegration:
         # these models would automatically use MLX
 
         # Check that the models exist and have the correct structure
-        assert hasattr(WHISPER_TURBO, 'inference_framework')
-        assert hasattr(WHISPER_TURBO, 'repo_id')
+        assert hasattr(WHISPER_TURBO, "inference_framework")
+        assert hasattr(WHISPER_TURBO, "repo_id")
 
-        assert hasattr(WHISPER_BASE, 'inference_framework')
-        assert hasattr(WHISPER_BASE, 'repo_id')
+        assert hasattr(WHISPER_BASE, "inference_framework")
+        assert hasattr(WHISPER_BASE, "repo_id")
 
-        assert hasattr(WHISPER_SMALL, 'inference_framework')
-        assert hasattr(WHISPER_SMALL, 'repo_id')
+        assert hasattr(WHISPER_SMALL, "inference_framework")
+        assert hasattr(WHISPER_SMALL, "repo_id")
 
-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_model_initialization(self, mock_import):
         """Test MLX Whisper model initialization."""
         # Mock the mlx_whisper import
@@ -102,7 +104,10 @@ class TestMlxWhisperIntegration:
             compression_ratio_threshold=2.4,
         )
 
-        with patch('builtins.__import__', side_effect=ImportError("No module named 'mlx_whisper'")):
+        with patch(
+            "builtins.__import__",
+            side_effect=ImportError("No module named 'mlx_whisper'"),
+        ):
             with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                 _MlxWhisperModel(
                     enabled=True,
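The patch above leans on a standard trick for testing optional-dependency failure paths: replacing builtins.__import__ makes any import statement inside the block raise. A stripped-down, docling-independent sketch of the technique:

    from unittest.mock import patch

    import pytest

    def load_optional_dependency():
        import mlx_whisper  # noqa: F401  # raises ImportError under the patch

        return True

    def test_missing_dependency_is_reported():
        with patch(
            "builtins.__import__",
            side_effect=ImportError("No module named 'mlx_whisper'"),
        ):
            with pytest.raises(ImportError, match="mlx_whisper"):
                load_optional_dependency()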
@@ -111,7 +116,7 @@ class TestMlxWhisperIntegration:
             asr_options=asr_options,
         )
 
-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
     def test_mlx_whisper_transcribe(self, mock_import):
         """Test MLX Whisper transcription method."""
         # Mock the mlx_whisper module and its transcribe function
@@ -128,7 +133,7 @@ class TestMlxWhisperIntegration:
                     "words": [
                         {"start": 0.0, "end": 0.5, "word": "Hello"},
                         {"start": 0.5, "end": 1.0, "word": "world"},
-                    ]
+                    ],
                 }
             ]
         }
@@ -178,7 +183,7 @@ class TestMlxWhisperIntegration:
             compression_ratio_threshold=2.4,
         )
 
-    @patch('builtins.__import__')
+    @patch("builtins.__import__")
    def test_asr_pipeline_with_mlx_whisper(self, mock_import):
         """Test that AsrPipeline can be initialized with MLX Whisper options."""
         # Mock the mlx_whisper import