updating with asr_options

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-06-18 06:23:33 +02:00
parent e5fd579861
commit ed10d09936
5 changed files with 174 additions and 39 deletions

docling/cli/main.py (View File)

@@ -32,6 +32,11 @@ from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
     WHISPER_TINY,
+    WHISPER_SMALL,
+    WHISPER_MEDIUM,
+    WHISPER_BASE,
+    WHISPER_LARGE,
+    WHISPER_TURBO,
     AsrModelType,
 )
 from docling.datamodel.base_models import (
@@ -641,9 +646,21 @@ def convert(  # noqa: C901
         if asr_model == AsrModelType.WHISPER_TINY:
             pipeline_options.asr_options = WHISPER_TINY
+        elif asr_model == AsrModelType.WHISPER_SMALL:
+            pipeline_options.asr_options = WHISPER_SMALL
+        elif asr_model == AsrModelType.WHISPER_MEDIUM:
+            pipeline_options.asr_options = WHISPER_MEDIUM
+        elif asr_model == AsrModelType.WHISPER_BASE:
+            pipeline_options.asr_options = WHISPER_BASE
+        elif asr_model == AsrModelType.WHISPER_LARGE:
+            pipeline_options.asr_options = WHISPER_LARGE
+        elif asr_model == AsrModelType.WHISPER_TURBO:
+            pipeline_options.asr_options = WHISPER_TURBO
         else:
-            _log.warning("falling back in base ASR model: WHISPER_TINY")
-            pipeline_options.asr_options = WHISPER_TINY
+            _log.error(f"{asr_model} is not known")
+            raise ValueError(f"{asr_model} is not known")
+
+        _log.info(f"pipeline_options: {pipeline_options}")

         audio_format_option = AudioFormatOption(
             pipeline_cls=AsrPipeline,
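
For reference, a minimal Python sketch of what the CLI branch above wires up, converting an audio file with one of the new presets. The import paths for DocumentConverter and AudioFormatOption are assumed from docling's package layout and are not shown in this diff:

    from docling.datamodel import asr_model_specs
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import AsrPipelineOptions
    from docling.document_converter import AudioFormatOption, DocumentConverter
    from docling.pipeline.asr_pipeline import AsrPipeline

    # select one of the presets introduced in this commit
    pipeline_options = AsrPipelineOptions(asr_options=asr_model_specs.WHISPER_TURBO)

    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )
    doc = converter.convert("meeting.wav").document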

docling/datamodel/asr_model_specs.py (View File)

@@ -7,9 +7,9 @@ from pydantic import (
 from docling.datamodel.accelerator_options import AcceleratorDevice
 from docling.datamodel.pipeline_options_asr_model import (
-    AsrResponseFormat,
+    # AsrResponseFormat,
     # ApiAsrOptions,
-    InferenceFramework,
+    InferenceAsrFramework,
     InlineAsrOptions,
     TransformersModelType,
 )
@@ -18,11 +18,77 @@ _log = logging.getLogger(__name__)

 # SmolDocling
 WHISPER_TINY = InlineAsrOptions(
-    repo_id="openai/whisper-tiny",
-    inference_framework=InferenceFramework.TRANSFORMERS,
-    response_format=AsrResponseFormat.WHISPER,
+    repo_id="tiny",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_SMALL = InlineAsrOptions(
+    repo_id="small",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_MEDIUM = InlineAsrOptions(
+    repo_id="medium",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_BASE = InlineAsrOptions(
+    repo_id="base",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_LARGE = InlineAsrOptions(
+    repo_id="large",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_TURBO = InlineAsrOptions(
+    repo_id="turbo",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
 )

 class AsrModelType(str, Enum):
     WHISPER_TINY = "whisper_tiny"
+    WHISPER_SMALL = "whisper_small"
+    WHISPER_MEDIUM = "whisper_medium"
+    WHISPER_BASE = "whisper_base"
+    WHISPER_LARGE = "whisper_large"
+    WHISPER_TURBO = "whisper_turbo"
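
Since the six presets differ only in repo_id, the if/elif chain in the CLI could equally be a lookup table; a sketch using only names introduced in this commit (the _ASR_SPECS dict itself is illustrative, not part of the commit):

    # maps each AsrModelType value to its preset options
    _ASR_SPECS = {
        AsrModelType.WHISPER_TINY: WHISPER_TINY,
        AsrModelType.WHISPER_SMALL: WHISPER_SMALL,
        AsrModelType.WHISPER_MEDIUM: WHISPER_MEDIUM,
        AsrModelType.WHISPER_BASE: WHISPER_BASE,
        AsrModelType.WHISPER_LARGE: WHISPER_LARGE,
        AsrModelType.WHISPER_TURBO: WHISPER_TURBO,
    }

    try:
        pipeline_options.asr_options = _ASR_SPECS[asr_model]
    except KeyError:
        raise ValueError(f"{asr_model} is not known")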

docling/datamodel/base_models.py (View File)

@@ -49,7 +49,7 @@ class InputFormat(str, Enum):
     XML_USPTO = "xml_uspto"
     XML_JATS = "xml_jats"
     JSON_DOCLING = "json_docling"
-    AUDIO = "wav"
+    AUDIO = "audio"

 class OutputFormat(str, Enum):

@@ -74,7 +74,7 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
     InputFormat.XLSX: ["xlsx", "xlsm"],
     InputFormat.XML_USPTO: ["xml", "txt"],
     InputFormat.JSON_DOCLING: ["json"],
-    InputFormat.AUDIO: ["wav"],
+    InputFormat.AUDIO: ["wav", "mp3"],
 }

 FormatToMimeType: Dict[InputFormat, List[str]] = {
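
With the extension table now covering mp3, suffix-based detection routes both audio types to InputFormat.AUDIO; an illustrative helper (not part of docling) showing how FormatToExtensions drives such a lookup:

    from pathlib import Path
    from typing import Optional

    def guess_format(path: Path) -> Optional[InputFormat]:
        # return the first format whose registered extensions match the suffix
        ext = path.suffix.lstrip(".").lower()
        for fmt, exts in FormatToExtensions.items():
            if ext in exts:
                return fmt
        return None

    assert guess_format(Path("talk.mp3")) == InputFormat.AUDIO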

docling/datamodel/pipeline_options_asr_model.py (View File)

@@ -6,7 +6,7 @@ from typing_extensions import deprecated
 from docling.datamodel.accelerator_options import AcceleratorDevice
 from docling.datamodel.pipeline_options_vlm_model import (
-    InferenceFramework,
+    # InferenceFramework,
     TransformersModelType,
 )
@@ -16,13 +16,35 @@ class BaseAsrOptions(BaseModel):
     # prompt: str

-class AsrResponseFormat(str, Enum):
+class InferenceAsrFramework(str, Enum):
+    MLX = "mlx"
+    TRANSFORMERS = "transformers"
     WHISPER = "whisper"

 class InlineAsrOptions(BaseAsrOptions):
     kind: Literal["inline_model_options"] = "inline_model_options"

+    repo_id: str
+    inference_framework: InferenceAsrFramework
+
+    verbose: bool = False
+    timestamps: bool = True
+    word_timestamps: bool = True
+
+    temperature: float = 0.0
+    max_new_tokens: int = 256
+    max_time_chunk: float = 30.0
+
+    torch_dtype: Optional[str] = None
+    supported_devices: List[AcceleratorDevice] = [
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.MPS,
+    ]
+
+    """
     repo_id: str
     trust_remote_code: bool = False
     load_in_8bit: bool = True

@@ -33,19 +55,13 @@ class InlineAsrOptions(BaseAsrOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: AsrResponseFormat

-    torch_dtype: Optional[str] = None
-    supported_devices: List[AcceleratorDevice] = [
-        AcceleratorDevice.CPU,
-        AcceleratorDevice.CUDA,
-        AcceleratorDevice.MPS,
-    ]
-
     temperature: float = 0.0
     stop_strings: List[str] = []
     extra_generation_config: Dict[str, Any] = {}
     use_kv_cache: bool = True
     max_new_tokens: int = 4096
+    """

     @property
     def repo_cache_folder(self) -> str:
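
Beyond the bundled presets, the reworked InlineAsrOptions can describe any Whisper checkpoint; a sketch under the assumption that repo_id accepts any name whisper.load_model() understands (e.g. "large-v3"):

    custom_whisper = InlineAsrOptions(
        repo_id="large-v3",  # assumed to be a valid whisper model name
        inference_framework=InferenceAsrFramework.WHISPER,
        word_timestamps=False,  # keep segment-level timestamps only
        max_time_chunk=15.0,
    )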

docling/pipeline/asr_pipeline.py (View File)

@@ -5,18 +5,21 @@ from io import BytesIO
 from pathlib import Path
 from typing import List, Optional, Union, cast

-import whisper  # type: ignore
+# import whisper  # type: ignore
 # import librosa
 # import numpy as np
 # import soundfile as sf  # type: ignore
 from docling_core.types.doc.labels import DocItemLabel
 from pydantic import BaseModel, Field, validator

+# from pydub import AudioSegment  # type: ignore
+# from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
 from docling.backend.abstract_backend import AbstractDocumentBackend
 from docling.backend.audio_backend import AudioBackend

-# from pydub import AudioSegment  # type: ignore
-# from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
+from docling.datamodel.accelerator_options import (
+    AcceleratorOptions,
+)
 from docling.datamodel.base_models import (
     ConversionStatus,
 )
@@ -25,7 +28,7 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
-    AsrResponseFormat,
+    # AsrResponseFormat,
     InlineAsrOptions,
 )
 from docling.datamodel.pipeline_options_vlm_model import (
@@ -88,15 +91,41 @@ class _ConversationItem(BaseModel):

 class _NativeWhisperModel:
-    def __init__(self, model_name: str = "medium"):
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        asr_options: InlineAsrOptions,
+        # model_name: str = "medium",
+    ):
         """
         Transcriber using native Whisper.
         """
+        self.enabled = enabled

-        self.model = whisper.load_model(model_name)
+        _log.info(f"artifacts-path: {artifacts_path}")
+        _log.info(f"accelerator_options: {accelerator_options}")

-        self.verbose = True
-        self.word_timestamps = True
+        if self.enabled:
+            try:
+                import whisper  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "whisper is not installed. Please install it via `pip install openai-whisper`."
+                )
+
+            self.asr_options = asr_options
+            self.max_tokens = asr_options.max_new_tokens
+            self.temperature = asr_options.temperature
+
+            self.model_name = asr_options.repo_id
+            _log.info(f"loading _NativeWhisperModel({self.model_name})")
+            self.model = whisper.load_model(self.model_name)
+
+            self.verbose = asr_options.verbose
+            self.timestamps = asr_options.timestamps
+            self.word_timestamps = asr_options.word_timestamps

     def run(self, conv_res: ConversionResult) -> ConversionResult:
         audio_path: Path = Path(conv_res.input.file).resolve()
@@ -126,15 +155,16 @@ class _NativeWhisperModel:
             item = _ConversationItem(
                 start_time=_["start"], end_time=_["end"], text=_["text"], words=[]
             )
-            item.words = []
-            for __ in _["words"]:
-                item.words.append(
-                    _ConversationWord(
-                        start_time=__["start"],
-                        end_time=__["end"],
-                        text=__["word"],
-                    )
-                )
+            if "words" in _ and self.word_timestamps:
+                item.words = []
+                for __ in _["words"]:
+                    item.words.append(
+                        _ConversationWord(
+                            start_time=__["start"],
+                            end_time=__["end"],
+                            text=__["word"],
+                        )
+                    )
             convo.append(item)

         return convo
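
For context, the dict shape consumed by the loop above is what openai-whisper's transcribe() returns; "words" only appears on a segment when word_timestamps=True is passed, which is why the new guard checks for the key:

    import whisper  # openai-whisper

    model = whisper.load_model("tiny")
    result = model.transcribe("meeting.wav", word_timestamps=True)
    for seg in result["segments"]:
        print(seg["start"], seg["end"], seg["text"])
        for w in seg.get("words", []):  # absent unless word timestamps were requested
            print("  ", w["start"], w["end"], w["word"])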
@@ -159,8 +189,15 @@ class AsrPipeline(BasePipeline):
                 "When defined, it must point to a folder containing all models required by the pipeline."
             )

-        # self._model = _WhisperModel()
-        self._model = _NativeWhisperModel()
+        if isinstance(self.pipeline_options.asr_options, InlineAsrOptions):
+            self._model = _NativeWhisperModel(
+                enabled=True,  # must always be enabled for this pipeline to make sense
+                artifacts_path=artifacts_path,
+                accelerator_options=pipeline_options.accelerator_options,
+                asr_options=pipeline_options.asr_options,
+            )
+        else:
+            _log.error(f"No ASR model support for {self.pipeline_options.asr_options}")

     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
         status = ConversionStatus.SUCCESS
@@ -171,10 +208,9 @@ class AsrPipeline(BasePipeline):
         return AsrPipelineOptions()

     def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
+        _log.info(f"start _build_document in AsrPipeline: {conv_res.input.file}")
+
         with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
-            _log.info(f"do something: {conv_res.input.file}")
             self._model.run(conv_res=conv_res)
-            _log.info(f"finished doing something: {conv_res.input.file}")

         return conv_res
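
Because transcription runs inside the doc_build TimeRecorder scope, its duration lands in the conversion result's profiling data; a sketch assuming docling's usual timings mapping on ConversionResult:

    result = converter.convert("meeting.wav")
    doc_build = result.timings.get("doc_build")  # assumes the profiling dict layout
    if doc_build is not None:
        print(f"transcription took {sum(doc_build.times):.2f}s")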