finalised the first working ASR pipeline with Whisper
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
  parent ed10d09936
  commit 43239ff712
@@ -31,11 +31,11 @@ from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
-    WHISPER_TINY,
-    WHISPER_SMALL,
-    WHISPER_MEDIUM,
     WHISPER_BASE,
     WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
     WHISPER_TURBO,
     AsrModelType,
 )
@@ -653,15 +653,15 @@ def convert(  # noqa: C901
     elif asr_model == AsrModelType.WHISPER_BASE:
         pipeline_options.asr_options = WHISPER_BASE
     elif asr_model == AsrModelType.WHISPER_LARGE:
         pipeline_options.asr_options = WHISPER_LARGE
     elif asr_model == AsrModelType.WHISPER_TURBO:
         pipeline_options.asr_options = WHISPER_TURBO
     else:
         _log.error(f"{asr_model} is not known")
         raise ValueError(f"{asr_model} is not known")

     _log.info(f"pipeline_options: {pipeline_options}")

     audio_format_option = AudioFormatOption(
         pipeline_cls=AsrPipeline,
         pipeline_options=pipeline_options,
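For orientation, the CLI mapping above has a direct programmatic equivalent. A minimal sketch, assuming the `InputFormat.AUDIO` format key and the `DocumentConverter` wiring that `AudioFormatOption` plugs into (the input file name is illustrative):

```python
from docling.datamodel.asr_model_specs import WHISPER_TURBO
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline

# Pick one of the specs the CLI maps to (WHISPER_TINY ... WHISPER_TURBO).
pipeline_options = AsrPipelineOptions()
pipeline_options.asr_options = WHISPER_TURBO

converter = DocumentConverter(
    format_options={
        InputFormat.AUDIO: AudioFormatOption(
            pipeline_cls=AsrPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("speech.wav")  # illustrative input path
print(result.document.export_to_markdown())
```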
@@ -10,14 +10,13 @@ from docling.datamodel.pipeline_options_asr_model import (
     # AsrResponseFormat,
     # ApiAsrOptions,
     InferenceAsrFramework,
-    InlineAsrOptions,
+    InlineAsrNativeWhisperOptions,
     TransformersModelType,
 )

 _log = logging.getLogger(__name__)

-# SmolDocling
-WHISPER_TINY = InlineAsrOptions(
+WHISPER_TINY = InlineAsrNativeWhisperOptions(
     repo_id="tiny",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -28,8 +27,7 @@ WHISPER_TINY = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-
-WHISPER_SMALL = InlineAsrOptions(
+WHISPER_SMALL = InlineAsrNativeWhisperOptions(
     repo_id="small",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -40,7 +38,7 @@ WHISPER_SMALL = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_MEDIUM = InlineAsrOptions(
+WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
     repo_id="medium",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -51,7 +49,7 @@ WHISPER_MEDIUM = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_BASE = InlineAsrOptions(
+WHISPER_BASE = InlineAsrNativeWhisperOptions(
     repo_id="base",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -62,7 +60,7 @@ WHISPER_BASE = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_LARGE = InlineAsrOptions(
+WHISPER_LARGE = InlineAsrNativeWhisperOptions(
     repo_id="large",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -73,7 +71,7 @@ WHISPER_LARGE = InlineAsrOptions(
     max_time_chunk=30.0,
 )

-WHISPER_TURBO = InlineAsrOptions(
+WHISPER_TURBO = InlineAsrNativeWhisperOptions(
     repo_id="turbo",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
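All six specs follow one template, so adding another Whisper checkpoint is mechanical. A hypothetical example ("large-v3" is a valid openai-whisper model alias, but this constant is not part of the commit; field values mirror the specs above):

```python
# Hypothetical extra spec, following the same pattern as WHISPER_TINY..TURBO.
WHISPER_LARGE_V3 = InlineAsrNativeWhisperOptions(
    repo_id="large-v3",  # openai-whisper resolves this alias to a checkpoint
    inference_framework=InferenceAsrFramework.WHISPER,
    verbose=True,
    timestamps=True,
    word_timestamps=True,
    temperature=0.0,
    max_new_tokens=256,
    max_time_chunk=30.0,
)
```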
@@ -27,11 +27,8 @@ class InlineAsrOptions(BaseAsrOptions):

     repo_id: str

-    inference_framework: InferenceAsrFramework
-
     verbose: bool = False
     timestamps: bool = True
-    word_timestamps: bool = True

     temperature: float = 0.0
     max_new_tokens: int = 256
@@ -44,25 +41,17 @@ class InlineAsrOptions(BaseAsrOptions):
         AcceleratorDevice.MPS,
     ]

-    """
-    repo_id: str
-    trust_remote_code: bool = False
-    load_in_8bit: bool = True
-    llm_int8_threshold: float = 6.0
-    quantized: bool = False
-
-    inference_framework: InferenceFramework
-    transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
-    response_format: AsrResponseFormat
-
-    temperature: float = 0.0
-    stop_strings: List[str] = []
-    extra_generation_config: Dict[str, Any] = {}
-
-    use_kv_cache: bool = True
-    max_new_tokens: int = 4096
-    """

     @property
     def repo_cache_folder(self) -> str:
         return self.repo_id.replace("/", "--")
+
+
+class InlineAsrNativeWhisperOptions(InlineAsrOptions):
+    inference_framework: InferenceAsrFramework = InferenceAsrFramework.WHISPER
+
+    language: str = "en"
+    supported_devices: List[AcceleratorDevice] = [
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+    ]
+    word_timestamps: bool = True
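The new subclass pins the inference framework and advertises its supported devices, while everything else stays on the `InlineAsrOptions` base. Assuming `BaseAsrOptions` is a pydantic model, as options classes elsewhere in docling are, per-instance overrides are straightforward; the `language` override below is illustrative:

```python
from docling.datamodel.pipeline_options_asr_model import (
    InlineAsrNativeWhisperOptions,
)

# Illustrative: a German-language variant of the "small" checkpoint.
opts = InlineAsrNativeWhisperOptions(repo_id="small", language="de")
print(opts.inference_framework)  # InferenceAsrFramework.WHISPER (pinned default)
print(opts.repo_cache_folder)    # "small" (no "/" to replace)
```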
@@ -28,6 +28,7 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
+    InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
 )
@@ -36,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
 )
 from docling.datamodel.settings import settings
 from docling.pipeline.base_pipeline import BasePipeline
+from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import ProfilingScope, TimeRecorder

 _log = logging.getLogger(__name__)
@@ -96,17 +98,16 @@ class _NativeWhisperModel:
         enabled: bool,
         artifacts_path: Optional[Path],
         accelerator_options: AcceleratorOptions,
-        asr_options: InlineAsrOptions,
-        # model_name: str = "medium",
+        asr_options: InlineAsrNativeWhisperOptions,
     ):
         """
         Transcriber using native Whisper.
         """
         self.enabled = enabled

         _log.info(f"artifacts-path: {artifacts_path}")
         _log.info(f"accelerator_options: {accelerator_options}")

         if self.enabled:
             try:
                 import whisper  # type: ignore
@@ -118,10 +119,25 @@ class _NativeWhisperModel:
             self.max_tokens = asr_options.max_new_tokens
             self.temperature = asr_options.temperature

+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=asr_options.supported_devices,
+            )
+            _log.info(f"Available device for Whisper: {self.device}")
+
             self.model_name = asr_options.repo_id
             _log.info(f"loading _NativeWhisperModel({self.model_name})")
-            self.model = whisper.load_model(self.model_name)
+            if artifacts_path is not None:
+                _log.info(f"loading {self.model_name} from {artifacts_path}")
+                self.model = whisper.load_model(
+                    name=self.model_name,
+                    device=self.device,
+                    download_root=str(artifacts_path),
+                )
+            else:
+                self.model = whisper.load_model(
+                    name=self.model_name, device=self.device
+                )

             self.verbose = asr_options.verbose
             self.timestamps = asr_options.timestamps
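The two branches map onto openai-whisper's documented signature, `load_model(name, device=None, download_root=None, in_memory=False)`; when `download_root` is omitted, checkpoints are cached under `~/.cache/whisper`, so the `artifacts_path` branch simply redirects that cache. A standalone sketch of the same call pattern (the input file name is illustrative):

```python
import whisper

# Same call pattern as above: model alias plus explicit device; with no
# download_root, the checkpoint is cached under ~/.cache/whisper.
model = whisper.load_model(name="tiny", device="cpu")
result = model.transcribe("speech.wav", word_timestamps=True, temperature=0.0)
for segment in result["segments"]:
    print(f"[{segment['start']:7.2f} - {segment['end']:7.2f}] {segment['text']}")
```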
@@ -189,15 +205,18 @@ class AsrPipeline(BasePipeline):
                 "When defined, it must point to a folder containing all models required by the pipeline."
             )

-        if isinstance(self.pipeline_options.asr_options, InlineAsrOptions):
+        if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
+            asr_options: InlineAsrNativeWhisperOptions = (
+                self.pipeline_options.asr_options
+            )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=pipeline_options.asr_options
+                asr_options=asr_options,
             )
         else:
-            _log.error(f"No model support for {self.pipeline_options.asr_options}")
+            _log.error(f"No model support for {self.pipeline_options.asr_options}")

     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
         status = ConversionStatus.SUCCESS
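The seemingly redundant local `asr_options` assignment narrows the type for static checkers: mypy does not retain a narrowed type for an attribute chain like `self.pipeline_options.asr_options` across later uses, but it does for a local variable. A generic illustration of the pattern (types here are illustrative, not from docling):

```python
class BaseOptions: ...

class WhisperOptions(BaseOptions):
    language: str = "en"

class Holder:
    options: BaseOptions

def configure(holder: Holder) -> None:
    if isinstance(holder.options, WhisperOptions):
        # Binding to a local keeps the narrowed type for later uses.
        opts: WhisperOptions = holder.options
        print(opts.language)

h = Holder()
h.options = WhisperOptions()
configure(h)  # prints "en"
```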