updating with asr_options

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

This commit is contained in:
  parent e5fd579861
  commit ed10d09936
@@ -32,6 +32,11 @@ from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
     WHISPER_TINY,
+    WHISPER_SMALL,
+    WHISPER_MEDIUM,
+    WHISPER_BASE,
+    WHISPER_LARGE,
+    WHISPER_TURBO,
     AsrModelType,
 )
 from docling.datamodel.base_models import (
@@ -641,10 +646,22 @@ def convert( # noqa: C901
 
         if asr_model == AsrModelType.WHISPER_TINY:
             pipeline_options.asr_options = WHISPER_TINY
+        elif asr_model == AsrModelType.WHISPER_SMALL:
+            pipeline_options.asr_options = WHISPER_SMALL
+        elif asr_model == AsrModelType.WHISPER_MEDIUM:
+            pipeline_options.asr_options = WHISPER_MEDIUM
+        elif asr_model == AsrModelType.WHISPER_BASE:
+            pipeline_options.asr_options = WHISPER_BASE
+        elif asr_model == AsrModelType.WHISPER_LARGE:
+            pipeline_options.asr_options = WHISPER_LARGE
+        elif asr_model == AsrModelType.WHISPER_TURBO:
+            pipeline_options.asr_options = WHISPER_TURBO
         else:
-            _log.warning("falling back in base ASR model: WHISPER_TINY")
-            pipeline_options.asr_options = WHISPER_TINY
+            _log.error(f"{asr_model} is not known")
+            raise ValueError(f"{asr_model} is not known")
 
+        _log.info(f"pipeline_options: {pipeline_options}")
+
         audio_format_option = AudioFormatOption(
             pipeline_cls=AsrPipeline,
             pipeline_options=pipeline_options,
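
The new CLI branch is a one-to-one dispatch from the asr_model enum value to a bundled spec, and unknown values now raise instead of silently falling back to WHISPER_TINY. For illustration only, the same dispatch as a lookup table (a sketch, not the committed code; _ASR_SPECS and resolve_asr_spec are hypothetical names):

from docling.datamodel.asr_model_specs import (
    WHISPER_TINY,
    WHISPER_SMALL,
    WHISPER_MEDIUM,
    WHISPER_BASE,
    WHISPER_LARGE,
    WHISPER_TURBO,
    AsrModelType,
)
from docling.datamodel.pipeline_options_asr_model import InlineAsrOptions

# Hypothetical table equivalent to the elif chain above.
_ASR_SPECS = {
    AsrModelType.WHISPER_TINY: WHISPER_TINY,
    AsrModelType.WHISPER_SMALL: WHISPER_SMALL,
    AsrModelType.WHISPER_MEDIUM: WHISPER_MEDIUM,
    AsrModelType.WHISPER_BASE: WHISPER_BASE,
    AsrModelType.WHISPER_LARGE: WHISPER_LARGE,
    AsrModelType.WHISPER_TURBO: WHISPER_TURBO,
}

def resolve_asr_spec(asr_model: AsrModelType) -> InlineAsrOptions:
    # Same behavior as the committed code: unknown values raise ValueError.
    try:
        return _ASR_SPECS[asr_model]
    except KeyError:
        raise ValueError(f"{asr_model} is not known")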
@ -7,9 +7,9 @@ from pydantic import (
|
|||||||
|
|
||||||
from docling.datamodel.accelerator_options import AcceleratorDevice
|
from docling.datamodel.accelerator_options import AcceleratorDevice
|
||||||
from docling.datamodel.pipeline_options_asr_model import (
|
from docling.datamodel.pipeline_options_asr_model import (
|
||||||
AsrResponseFormat,
|
# AsrResponseFormat,
|
||||||
# ApiAsrOptions,
|
# ApiAsrOptions,
|
||||||
InferenceFramework,
|
InferenceAsrFramework,
|
||||||
InlineAsrOptions,
|
InlineAsrOptions,
|
||||||
TransformersModelType,
|
TransformersModelType,
|
||||||
)
|
)
|
||||||
@@ -18,11 +18,77 @@ _log = logging.getLogger(__name__)
 
 # SmolDocling
 WHISPER_TINY = InlineAsrOptions(
-    repo_id="openai/whisper-tiny",
-    inference_framework=InferenceFramework.TRANSFORMERS,
-    response_format=AsrResponseFormat.WHISPER,
+    repo_id="tiny",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_SMALL = InlineAsrOptions(
+    repo_id="small",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_MEDIUM = InlineAsrOptions(
+    repo_id="medium",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_BASE = InlineAsrOptions(
+    repo_id="base",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_LARGE = InlineAsrOptions(
+    repo_id="large",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
+)
+
+WHISPER_TURBO = InlineAsrOptions(
+    repo_id="turbo",
+    inference_framework=InferenceAsrFramework.WHISPER,
+    verbose=True,
+    timestamps=True,
+    word_timestamps=True,
+    temperature=0.0,
+    max_new_tokens=256,
+    max_time_chunk=30.0,
 )
 
 
 class AsrModelType(str, Enum):
     WHISPER_TINY = "whisper_tiny"
+    WHISPER_SMALL = "whisper_small"
+    WHISPER_MEDIUM = "whisper_medium"
+    WHISPER_BASE = "whisper_base"
+    WHISPER_LARGE = "whisper_large"
+    WHISPER_TURBO = "whisper_turbo"
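
Note the repo_id values switch from a Hugging Face id ("openai/whisper-tiny") to bare openai-whisper checkpoint names ("tiny", "small", ...), matching the change of inference_framework to InferenceAsrFramework.WHISPER. A minimal sketch of the underlying framework call those names feed into (assumes openai-whisper is installed; the input file name is hypothetical):

import whisper  # pip install openai-whisper

# "tiny" is exactly what WHISPER_TINY.repo_id now carries.
model = whisper.load_model("tiny")
result = model.transcribe(
    "sample.wav",        # hypothetical input file
    temperature=0.0,     # deterministic decoding, as in the specs above
    word_timestamps=True,
)
print(result["text"])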

@@ -49,7 +49,7 @@ class InputFormat(str, Enum):
     XML_USPTO = "xml_uspto"
     XML_JATS = "xml_jats"
     JSON_DOCLING = "json_docling"
-    AUDIO = "wav"
+    AUDIO = "audio"
 
 
 class OutputFormat(str, Enum):

@@ -74,7 +74,7 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
     InputFormat.XLSX: ["xlsx", "xlsm"],
     InputFormat.XML_USPTO: ["xml", "txt"],
     InputFormat.JSON_DOCLING: ["json"],
-    InputFormat.AUDIO: ["wav"],
+    InputFormat.AUDIO: ["wav", "mp3"],
 }
 
 FormatToMimeType: Dict[InputFormat, List[str]] = {
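
With AUDIO renamed to a generic format value and carrying both wav and mp3 extensions, extension-based detection can treat audio like any other input format. A hedged sketch of such a reverse lookup (illustrative only; docling's actual format detection lives elsewhere, and guess_format_from_extension is a hypothetical helper):

from typing import Optional

from docling.datamodel.base_models import FormatToExtensions, InputFormat

def guess_format_from_extension(filename: str) -> Optional[InputFormat]:
    # Reverse lookup over the FormatToExtensions table shown above.
    # Note: ambiguous extensions (e.g. "xml") resolve to the first match.
    ext = filename.rsplit(".", 1)[-1].lower()
    for fmt, exts in FormatToExtensions.items():
        if ext in exts:
            return fmt
    return None

assert guess_format_from_extension("meeting.mp3") == InputFormat.AUDIO
assert guess_format_from_extension("report.json") == InputFormat.JSON_DOCLING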

@@ -6,7 +6,7 @@ from typing_extensions import deprecated
 
 from docling.datamodel.accelerator_options import AcceleratorDevice
 from docling.datamodel.pipeline_options_vlm_model import (
-    InferenceFramework,
+    # InferenceFramework,
     TransformersModelType,
 )
 
@@ -16,13 +16,35 @@ class BaseAsrOptions(BaseModel):
     # prompt: str
 
 
-class AsrResponseFormat(str, Enum):
+class InferenceAsrFramework(str, Enum):
+    MLX = "mlx"
+    TRANSFORMERS = "transformers"
     WHISPER = "whisper"
 
 
 class InlineAsrOptions(BaseAsrOptions):
     kind: Literal["inline_model_options"] = "inline_model_options"
 
+    repo_id: str
+
+    inference_framework: InferenceAsrFramework
+
+    verbose: bool = False
+    timestamps: bool = True
+    word_timestamps: bool = True
+
+    temperature: float = 0.0
+    max_new_tokens: int = 256
+    max_time_chunk: float = 30.0
+
+    torch_dtype: Optional[str] = None
+    supported_devices: List[AcceleratorDevice] = [
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.MPS,
+    ]
+
+    """
     repo_id: str
     trust_remote_code: bool = False
     load_in_8bit: bool = True

@@ -33,19 +55,13 @@ class InlineAsrOptions(BaseAsrOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: AsrResponseFormat
 
-    torch_dtype: Optional[str] = None
-    supported_devices: List[AcceleratorDevice] = [
-        AcceleratorDevice.CPU,
-        AcceleratorDevice.CUDA,
-        AcceleratorDevice.MPS,
-    ]
-
     temperature: float = 0.0
     stop_strings: List[str] = []
     extra_generation_config: Dict[str, Any] = {}
 
     use_kv_cache: bool = True
     max_new_tokens: int = 4096
+    """
 
     @property
     def repo_cache_folder(self) -> str:
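
Since InlineAsrOptions is a pydantic model whose new fields all have defaults except repo_id and inference_framework (the remaining transformers-era fields are fenced off in the docstring block), a caller can define a custom spec in the same shape as the bundled WHISPER_* constants. A sketch (the "base.en" checkpoint name is a hypothetical choice):

from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrOptions,
)

# Hypothetical custom spec, same shape as the bundled WHISPER_* constants.
WHISPER_BASE_EN = InlineAsrOptions(
    repo_id="base.en",  # any checkpoint name whisper.load_model() accepts
    inference_framework=InferenceAsrFramework.WHISPER,
    verbose=False,
    timestamps=True,
    word_timestamps=True,
    temperature=0.0,
    max_new_tokens=256,
    max_time_chunk=30.0,
)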

@@ -5,18 +5,21 @@ from io import BytesIO
 from pathlib import Path
 from typing import List, Optional, Union, cast
 
-import whisper  # type: ignore
+# import whisper  # type: ignore
 
 # import librosa
 # import numpy as np
 # import soundfile as sf  # type: ignore
 from docling_core.types.doc.labels import DocItemLabel
 from pydantic import BaseModel, Field, validator
 
-# from pydub import AudioSegment  # type: ignore
-# from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
 from docling.backend.abstract_backend import AbstractDocumentBackend
 from docling.backend.audio_backend import AudioBackend
+
+# from pydub import AudioSegment  # type: ignore
+# from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
+from docling.datamodel.accelerator_options import (
+    AcceleratorOptions,
+)
 from docling.datamodel.base_models import (
     ConversionStatus,
 )
@@ -25,7 +28,7 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
-    AsrResponseFormat,
+    # AsrResponseFormat,
     InlineAsrOptions,
 )
 from docling.datamodel.pipeline_options_vlm_model import (
|
|||||||
|
|
||||||
|
|
||||||
class _NativeWhisperModel:
|
class _NativeWhisperModel:
|
||||||
def __init__(self, model_name: str = "medium"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool,
|
||||||
|
artifacts_path: Optional[Path],
|
||||||
|
accelerator_options: AcceleratorOptions,
|
||||||
|
asr_options: InlineAsrOptions,
|
||||||
|
# model_name: str = "medium",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Transcriber using native Whisper.
|
Transcriber using native Whisper.
|
||||||
"""
|
"""
|
||||||
|
self.enabled = enabled
|
||||||
|
|
||||||
|
_log.info(f"artifacts-path: {artifacts_path}")
|
||||||
|
_log.info(f"accelerator_options: {accelerator_options}")
|
||||||
|
|
||||||
|
if self.enabled:
|
||||||
|
try:
|
||||||
|
import whisper # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"whisper is not installed. Please install it via `pip install openai-whisper`."
|
||||||
|
)
|
||||||
|
self.asr_options = asr_options
|
||||||
|
self.max_tokens = asr_options.max_new_tokens
|
||||||
|
self.temperature = asr_options.temperature
|
||||||
|
|
||||||
self.model = whisper.load_model(model_name)
|
|
||||||
|
self.model_name = asr_options.repo_id
|
||||||
|
_log.info(f"loading _NativeWhisperModel({self.model_name})")
|
||||||
|
self.model = whisper.load_model(self.model_name)
|
||||||
|
|
||||||
self.verbose = True
|
self.verbose = asr_options.verbose
|
||||||
self.word_timestamps = True
|
self.timestamps = asr_options.timestamps
|
||||||
|
self.word_timestamps = asr_options.word_timestamps
|
||||||
|
|
||||||
def run(self, conv_res: ConversionResult) -> ConversionResult:
|
def run(self, conv_res: ConversionResult) -> ConversionResult:
|
||||||
audio_path: Path = Path(conv_res.input.file).resolve()
|
audio_path: Path = Path(conv_res.input.file).resolve()
|
||||||
@@ -126,15 +155,16 @@ class _NativeWhisperModel:
             item = _ConversationItem(
                 start_time=_["start"], end_time=_["end"], text=_["text"], words=[]
             )
-            item.words = []
-            for __ in _["words"]:
-                item.words.append(
-                    _ConversationWord(
-                        start_time=__["start"],
-                        end_time=__["end"],
-                        text=__["word"],
-                    )
-                )
+            if "words" in _ and self.word_timestamps:
+                item.words = []
+                for __ in _["words"]:
+                    item.words.append(
+                        _ConversationWord(
+                            start_time=__["start"],
+                            end_time=__["end"],
+                            text=__["word"],
+                        )
+                    )
             convo.append(item)
 
         return convo
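
This loop consumes the dictionary returned by openai-whisper's model.transcribe(); per-word entries only appear when word_timestamps=True, which is what the new "words" guard accounts for. The shape, trimmed to the keys the loop actually reads (values are illustrative):

# Trimmed example of whisper's transcribe() output as consumed above.
result = {
    "text": " Hello world.",
    "language": "en",
    "segments": [
        {
            "start": 0.0,
            "end": 1.2,
            "text": " Hello world.",
            # "words" is only present when word_timestamps=True
            "words": [
                {"word": " Hello", "start": 0.0, "end": 0.6},
                {"word": " world.", "start": 0.6, "end": 1.1},
            ],
        },
    ],
}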
@@ -159,8 +189,15 @@ class AsrPipeline(BasePipeline):
                 "When defined, it must point to a folder containing all models required by the pipeline."
             )
 
-        # self._model = _WhisperModel()
-        self._model = _NativeWhisperModel()
+        if isinstance(self.pipeline_options.asr_options, InlineAsrOptions):
+            self._model = _NativeWhisperModel(
+                enabled=True,  # must always be enabled for this pipeline to make sense
+                artifacts_path=artifacts_path,
+                accelerator_options=pipeline_options.accelerator_options,
+                asr_options=pipeline_options.asr_options,
+            )
+        else:
+            _log.error(f"Unsupported ASR options: {pipeline_options.asr_options}")
 
     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
         status = ConversionStatus.SUCCESS
@ -171,10 +208,9 @@ class AsrPipeline(BasePipeline):
|
|||||||
return AsrPipelineOptions()
|
return AsrPipelineOptions()
|
||||||
|
|
||||||
def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
|
def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
|
||||||
|
_log.info(f"start _build_document in AsrPipeline: {conv_res.input.file}")
|
||||||
with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
|
with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
|
||||||
_log.info(f"do something: {conv_res.input.file}")
|
|
||||||
self._model.run(conv_res=conv_res)
|
self._model.run(conv_res=conv_res)
|
||||||
_log.info(f"finished doing something: {conv_res.input.file}")
|
|
||||||
|
|
||||||
return conv_res
|
return conv_res
|
||||||
|
|
||||||
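
Taken together, the pieces above wire audio conversion end to end: a caller picks a WHISPER_* spec, puts it on AsrPipelineOptions, and registers an AudioFormatOption for InputFormat.AUDIO, exactly as the CLI hunk does. A hedged usage sketch (the AudioFormatOption/DocumentConverter import path is assumed, not shown in this commit; meeting.mp3 is a hypothetical file):

from docling.datamodel.asr_model_specs import WHISPER_TURBO
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline

pipeline_options = AsrPipelineOptions()
pipeline_options.asr_options = WHISPER_TURBO

converter = DocumentConverter(
    format_options={
        InputFormat.AUDIO: AudioFormatOption(
            pipeline_cls=AsrPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("meeting.mp3")  # hypothetical input file
print(result.document.export_to_markdown())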