WIP: got first transcription working

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-06-13 10:43:23 +02:00
parent 1d4008ac7c
commit 6dead88464
9 changed files with 352 additions and 67 deletions

View File

@@ -0,0 +1,80 @@
import logging
import warnings
from io import BytesIO, StringIO
from pathlib import Path
from typing import Set, Union
from docling_core.types.doc import (
DoclingDocument,
DocumentOrigin,
)
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
class AudioBackend(DeclarativeDocumentBackend):
# content: StringIO
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
_log.info(f"path: {path_or_stream}")
# Load content
try:
if isinstance(self.path_or_stream, BytesIO):
_log.info(f"reading streaming: {self.path_or_stream}")
# self.content = StringIO(self.path_or_stream.getvalue().decode("utf-8"))
elif isinstance(self.path_or_stream, Path):
_log.info(f"reading file: {self.path_or_stream}")
# self.content = StringIO(self.path_or_stream.read())
self.valid = True
except Exception as e:
raise RuntimeError(
f"AudioBackend could not load document with hash {self.document_hash}"
) from e
return
def is_valid(self) -> bool:
return self.valid
@classmethod
def supports_pagination(cls) -> bool:
return False
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
self.path_or_stream = None
@classmethod
def supported_formats(cls) -> Set[InputFormat]:
return {InputFormat.AUDIO_WAV}
def convert(self) -> DoclingDocument:
"""
Parses the audio file into a structured document model.
"""
        # Build the document origin and skeleton for the audio input
origin = DocumentOrigin(
filename=self.file.name or "audio.wav",
mimetype="audio/wav",
binary_hash=self.document_hash,
)
_log.info(f"origin: {origin}")
doc = DoclingDocument(name=self.file.stem or "audio.wav", origin=origin)
if self.is_valid():
_log.error("time to get going ...")
else:
raise RuntimeError(
f"Cannot convert doc with {self.document_hash} because the audio backend failed to init."
)
return doc
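
A quick way to sanity-check the backend in isolation; a minimal sketch, assuming a local sample.wav and that InputDocument wires up the backend on its private _backend attribute (as the converter does internally):

from pathlib import Path

from docling.backend.audio_backend import AudioBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument

in_doc = InputDocument(
    path_or_stream=Path("sample.wav"),
    format=InputFormat.AUDIO_WAV,
    backend=AudioBackend,
)
backend = in_doc._backend
print(backend.is_valid())      # True once the file has been loaded
print(backend.convert().name)  # "sample" (the file stem)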

View File

@@ -23,6 +23,7 @@ from docling_core.utils.file import resolve_source_to_path
from pydantic import TypeAdapter
from rich.console import Console
from docling.backend.audio_backend import AudioBackend
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
@@ -59,7 +60,12 @@ from docling.datamodel.vlm_model_specs import (
SMOLDOCLING_TRANSFORMERS,
VlmModelType,
)
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.document_converter import (
AudioFormatOption,
DocumentConverter,
FormatOption,
PdfFormatOption,
)
from docling.models.factories import get_ocr_factory
from docling.pipeline.asr_pipeline import AsrPipeline
from docling.pipeline.vlm_pipeline import VlmPipeline
@@ -543,40 +549,8 @@ def convert( # noqa: C901
pipeline_options: PaginatedPipelineOptions
format_options: Dict[InputFormat, FormatOption] = {}
if pipeline == ProcessingPipeline.VLM:
pipeline_options = VlmPipelineOptions(
enable_remote_services=enable_remote_services,
)
if vlm_model == VlmModelType.GRANITE_VISION:
pipeline_options.vlm_options = GRANITE_VISION_TRANSFORMERS
elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
pipeline_options.vlm_options = GRANITE_VISION_OLLAMA
elif vlm_model == VlmModelType.SMOLDOCLING:
pipeline_options.vlm_options = SMOLDOCLING_TRANSFORMERS
if sys.platform == "darwin":
try:
import mlx_vlm
pipeline_options.vlm_options = SMOLDOCLING_MLX
except ImportError:
_log.warning(
"To run SmolDocling faster, please install mlx-vlm:\n"
"pip install mlx-vlm"
)
pdf_format_option = PdfFormatOption(
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
)
format_options: Dict[InputFormat, FormatOption] = {
InputFormat.PDF: pdf_format_option,
InputFormat.IMAGE: pdf_format_option,
}
elif pipeline == ProcessingPipeline.STANDARD:
if pipeline == ProcessingPipeline.STANDARD:
pipeline_options = PdfPipelineOptions(
allow_external_plugins=allow_external_plugins,
enable_remote_services=enable_remote_services,
@@ -623,23 +597,59 @@ def convert( # noqa: C901
InputFormat.PDF: pdf_format_option,
InputFormat.IMAGE: pdf_format_option,
}
    elif pipeline == ProcessingPipeline.ASR:
        audio_pipeline_options = AsrPipelineOptions(
            # enable_remote_services=enable_remote_services,
            artifacts_path = artifacts_path
        )
        audio_format_option = PdfFormatOption(
            pipeline_cls=AsrPipeline,
            pipeline_options=audio_pipeline_options,
            # backend = FIXME
        )
        format_options: Dict[InputFormat, FormatOption] = {
            InputFormat.AUDIO_WAV: audio_format_option,
        }
        """
        if asr_model == AsrModelType.WHISPER_TINY:
            pipeline_options.asr_options = WHISPER_TINY:
    elif pipeline == ProcessingPipeline.VLM:
        pipeline_options = VlmPipelineOptions(
            enable_remote_services=enable_remote_services,
        )
        if vlm_model == VlmModelType.GRANITE_VISION:
            pipeline_options.vlm_options = GRANITE_VISION_TRANSFORMERS
        elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
            pipeline_options.vlm_options = GRANITE_VISION_OLLAMA
        elif vlm_model == VlmModelType.SMOLDOCLING:
            pipeline_options.vlm_options = SMOLDOCLING_TRANSFORMERS
            if sys.platform == "darwin":
                try:
                    import mlx_vlm

                    pipeline_options.vlm_options = SMOLDOCLING_MLX
                except ImportError:
                    _log.warning(
                        "To run SmolDocling faster, please install mlx-vlm:\n"
                        "pip install mlx-vlm"
                    )
        pdf_format_option = PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
        format_options: Dict[InputFormat, FormatOption] = {
            InputFormat.PDF: pdf_format_option,
            InputFormat.IMAGE: pdf_format_option,
        }
    elif pipeline == ProcessingPipeline.ASR:
        pipeline_options = AsrPipelineOptions(
            # enable_remote_services=enable_remote_services,
            # artifacts_path = artifacts_path
        )
        if asr_model == AsrModelType.WHISPER_TINY:
            pipeline_options.asr_options = WHISPER_TINY
        else:
            pipeline_options.asr_options = WHISPER_TINY
        audio_format_option = AudioFormatOption(
            pipeline_cls=AsrPipeline,
            pipeline_options=pipeline_options,
            backend=AudioBackend,
        )
        format_options: Dict[InputFormat, FormatOption] = {
            InputFormat.AUDIO_WAV: audio_format_option,
        }
@@ -656,6 +666,7 @@ def convert( # noqa: C901
start_time = time.time()
_log.info(f"paths: {input_doc_paths}")
conv_results = doc_converter.convert_all(
input_doc_paths, headers=parsed_headers, raises_on_error=abort_on_error
)
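
For an end-to-end smoke test, the new branch can be exercised from the CLI along the lines of `docling --pipeline asr --asr-model whisper_tiny sample.wav`; the flag spellings are assumed from the `pipeline` and `asr_model` options above, and `sample.wav` is a stand-in path.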

View File

@@ -7,10 +7,10 @@ from pydantic import (
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_asr_model import (
AsrResponseFormat,
# ApiAsrOptions,
InferenceFramework,
InlineAsrOptions,
AsrResponseFormat,
TransformersModelType,
)
@@ -20,8 +20,9 @@ _log = logging.getLogger(__name__)
WHISPER_TINY = InlineAsrOptions(
repo_id="openai/whisper-tiny",
inference_framework=InferenceFramework.TRANSFORMERS,
response_format = AsrResponseFormat.WHISPER,
response_format=AsrResponseFormat.WHISPER,
)
class AsrModelType(str, Enum):
WHISPER_TINY = "whisper_tiny"

View File

@@ -322,7 +322,7 @@ class _DocumentConversionInput(BaseModel):
mime = mime or "text/plain"
formats = MimeTypeToFormat.get(mime, [])
print(formats)
if formats:
        if len(formats) == 1 and mime not in ("text/plain",):
return formats[0]

View File

@@ -5,7 +5,11 @@ from pydantic import AnyUrl, BaseModel
from typing_extensions import deprecated
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_vlm_model import InferenceFramework, TransformersModelType
from docling.datamodel.pipeline_options_vlm_model import (
InferenceFramework,
TransformersModelType,
)
class BaseAsrOptions(BaseModel):
kind: str
@@ -15,7 +19,7 @@ class BaseAsrOptions(BaseModel):
class AsrResponseFormat(str, Enum):
WHISPER = "whisper"
class InlineAsrOptions(BaseAsrOptions):
kind: Literal["inline_model_options"] = "inline_model_options"
@@ -46,5 +50,3 @@ class InlineAsrOptions(BaseAsrOptions):
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
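
repo_cache_folder simply flattens the Hugging Face repo id into a filesystem-safe directory name; for example, assuming the fields shown here are the only required ones:

opts = InlineAsrOptions(
    repo_id="openai/whisper-tiny",
    inference_framework=InferenceFramework.TRANSFORMERS,
    response_format=AsrResponseFormat.WHISPER,
)
print(opts.repo_cache_folder)  # openai--whisper-tiny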

View File

@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, model_validator, validate_call
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.asciidoc_backend import AsciiDocBackend
from docling.backend.audio_backend import AudioBackend
from docling.backend.csv_backend import CsvDocumentBackend
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
@@ -41,6 +42,7 @@ from docling.datamodel.settings import (
settings,
)
from docling.exceptions import ConversionError
from docling.pipeline.asr_pipeline import AsrPipeline
from docling.pipeline.base_pipeline import BasePipeline
from docling.pipeline.simple_pipeline import SimplePipeline
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
@@ -118,6 +120,11 @@ class PdfFormatOption(FormatOption):
backend: Type[AbstractDocumentBackend] = DoclingParseV4DocumentBackend
class AudioFormatOption(FormatOption):
pipeline_cls: Type = AsrPipeline
backend: Type[AbstractDocumentBackend] = AudioBackend
def _get_default_option(format: InputFormat) -> FormatOption:
format_to_default_options = {
InputFormat.CSV: FormatOption(
@@ -156,6 +163,9 @@ def _get_default_option(format: InputFormat) -> FormatOption:
InputFormat.JSON_DOCLING: FormatOption(
pipeline_cls=SimplePipeline, backend=DoclingJSONBackend
),
InputFormat.AUDIO_WAV: FormatOption(
pipeline_cls=AsrPipeline, backend=AudioBackend
),
}
if (options := format_to_default_options.get(format)) is not None:
return options
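
With the AUDIO_WAV default registered above, a WAV file should convert without extra wiring; spelling the configuration out explicitly looks roughly like this (a sketch, assuming a local sample.wav):

from docling.datamodel.base_models import InputFormat
from docling.document_converter import AudioFormatOption, DocumentConverter
from docling.pipeline.asr_pipeline import AsrPipeline

converter = DocumentConverter(
    format_options={
        InputFormat.AUDIO_WAV: AudioFormatOption(pipeline_cls=AsrPipeline)
    }
)
result = converter.convert("sample.wav")
print(result.document.export_to_markdown())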

View File

@@ -1,30 +1,179 @@
import logging
import os
import re
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Union, cast
from docling.backend.abstract_backend import AbstractDocumentBackend
import soundfile as sf
from docling_core.types.doc.labels import DocItemLabel
from pydantic import BaseModel, Field, validator
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.audio_backend import AudioBackend
from docling.datamodel.base_models import (
ConversionStatus,
)
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
AsrPipelineOptions,
)
from docling.datamodel.pipeline_options_asr_model import (
AsrResponseFormat,
InlineAsrOptions,
)
from docling.datamodel.pipeline_options_vlm_model import (
InferenceFramework,
)
from docling.datamodel.pipeline_options_asr_model import (
InlineAsrOptions,
AsrResponseFormat,
)
from docling.datamodel.settings import settings
from docling.pipeline.base_pipeline import BasePipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.datamodel.document import ConversionResult, InputDocument
_log = logging.getLogger(__name__)
class ConversationEntry(BaseModel):
text: str
start_time: float = Field(
..., ge=0, description="Start time in seconds from video start"
)
end_time: float = Field(
..., ge=0, description="End time in seconds from video start"
)
speaker_id: int = Field(..., ge=0, description="Numeric speaker identifier")
speaker: Optional[str] = Field(
None, description="Speaker name, defaults to speaker-{speaker_id}"
)
@validator("end_time")
def end_time_must_be_after_start(cls, v, values):
if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time must be greater than start_time")
return v
@validator("speaker", always=True)
def set_default_speaker_name(cls, v, values):
if v is None and "speaker_id" in values:
return f"speaker-{values['speaker_id']}"
return v
def __lt__(self, other):
if not isinstance(other, ConversationEntry):
return NotImplemented
return self.start_time < other.start_time
def __eq__(self, other):
if not isinstance(other, ConversationEntry):
return NotImplemented
return self.start_time == other.start_time
def to_string(self) -> str:
"""Format the conversation entry as a string"""
return f"[time: {self.start_time}-{self.end_time}] [speaker:{self.speaker}] {self.text}"
class _WhisperModel:
def __init__(self):
_log.info("initialisation `_WhisperModel`")
from transformers import WhisperForConditionalGeneration, WhisperProcessor
self.model_repo = "openai/whisper-tiny"
self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
self.model = WhisperForConditionalGeneration.from_pretrained(
"openai/whisper-tiny"
)
_log.info(f"model is loaded: {self.model_repo}")
def run(self, conv_res: ConversionResult):
# fpath = Path(conv_res.input.file)
# _log.info(f"`_WhisperModel::run: {conv_res}`")
_log.info(f"`_WhisperModel::run: {conv_res.input}`")
_log.info(f"`_WhisperModel::run: {conv_res.input.file}`")
if os.path.exists(str(conv_res.input.file)):
print("file exists")
else:
print("file does not exist")
#
_log.info(f"sampling-rate: {self.processor.feature_extractor.sampling_rate}")
try:
fpath = conv_res.input.file
# array, sampling_rate = sf.read(fpath)#, samplerate=processor.feature_extractor.sampling_rate)
array, sampling_rate = sf.read(
fpath
) # , samplerate=self.processor.feature_extractor.sampling_rate)
_log.info(
f"read the file .. (sampling-rate: {sampling_rate}, array: {array.shape})"
)
processed_input = self.processor(
array,
sampling_rate=self.processor.feature_extractor.sampling_rate, # sampling_rate,
return_tensors="pt",
)
print(processed_input)
# pre-process to get the input features
input_features = self.processor(
array, sampling_rate=sampling_rate, return_tensors="pt"
).input_features
_log.info(f"got input-features: {input_features.shape}")
# generate token ids by running model forward sequentially
predicted_ids = self.model.generate(
input_features, max_new_tokens=256, return_timestamps=True
)
_log.info("ran model ..")
"""
transcription = self.processor.batch_decode(predicted_ids,
skip_special_tokens=False,
decode_with_timestamps=True)
_log.info("decoded output ..")
print(f"Transcription: {transcription}")
"""
conversation = []
print("Timestamp info:")
for pidi, pid in enumerate(predicted_ids):
# timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
timestamps = self.processor.tokenizer.decode(pid, output_offsets=True)
print(f"Predicted id [{pidi}]: {timestamps['text']}")
for offset in timestamps["offsets"]:
print(f" => {offset['timestamp']}: {offset['text']}")
item = ConversationEntry(
text=offset["text"],
speaker_id=pidi,
start_time=offset["timestamp"][0],
end_time=offset["timestamp"][1],
)
conv_res.document.add_text(
label=DocItemLabel.TEXT, text=item.to_string()
)
conv_res.status = ConversionStatus.SUCCESS
print("document: \n\n", conv_res.document.export_to_markdown())
except Exception as exc:
conv_res.status = ConversionStatus.FAILED
print(exc)
return conv_res
class AsrPipeline(BasePipeline):
def __init__(self, pipeline_options: AsrPipelineOptions):
super().__init__(pipeline_options)
@@ -44,19 +193,24 @@ class AsrPipeline(BasePipeline):
"When defined, it must point to a folder containing all models required by the pipeline."
)
self._model = _WhisperModel()
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
status = ConversionStatus.SUCCESS
return status
@classmethod
def get_default_options(cls) -> AsrPipelineOptions:
return AsrPipelineOptions()
def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
total_elapsed_time = 0.0
with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
print("do something")
_log.info(f"do something: {conv_res.input.file}")
self._model.run(conv_res=conv_res)
_log.info(f"finished doing something: {conv_res.input.file}")
return conv_res
"""
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
status = ConversionStatus()
return status
"""
@classmethod
@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend):
return True
return isinstance(backend, AudioBackend)

View File

@@ -99,6 +99,9 @@ rapidocr = [
# 'onnxruntime (>=1.7.0,<2.0.0) ; python_version >= "3.10"',
# 'onnxruntime (>=1.7.0,<1.20.0) ; python_version < "3.10"',
]
asr = [
"soundfile>=0.13.1",
]
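
With this extra in place the ASR dependencies install via `pip install "docling[asr]"`; note that transformers, which _WhisperModel imports at runtime, is currently only pulled in by the vlm extra (see requires-dist below).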
[dependency-groups]
dev = [

uv.lock generated
View File

@@ -844,6 +844,9 @@ dependencies = [
]
[package.optional-dependencies]
asr = [
{ name = "soundfile" },
]
ocrmac = [
{ name = "ocrmac", marker = "sys_platform == 'darwin'" },
]
@@ -932,12 +935,13 @@ requires-dist = [
{ name = "requests", specifier = ">=2.32.2,<3.0.0" },
{ name = "rtree", specifier = ">=1.3.0,<2.0.0" },
{ name = "scipy", specifier = ">=1.6.0,<2.0.0" },
{ name = "soundfile", marker = "extra == 'asr'", specifier = ">=0.13.1" },
{ name = "tesserocr", marker = "extra == 'tesserocr'", specifier = ">=2.7.1,<3.0.0" },
{ name = "tqdm", specifier = ">=4.65.0,<5.0.0" },
{ name = "transformers", marker = "extra == 'vlm'", specifier = ">=4.46.0,<5.0.0" },
{ name = "typer", specifier = ">=0.12.5,<0.17.0" },
]
provides-extras = ["tesserocr", "ocrmac", "vlm", "rapidocr"]
provides-extras = ["tesserocr", "ocrmac", "vlm", "rapidocr", "asr"]
[package.metadata.requires-dev]
constraints = [
@@ -5764,6 +5768,26 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
]
[[package]]
name = "soundfile"
version = "0.13.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e1/41/9b873a8c055582859b239be17902a85339bec6a30ad162f98c9b0288a2cc/soundfile-0.13.1.tar.gz", hash = "sha256:b2c68dab1e30297317080a5b43df57e302584c49e2942defdde0acccc53f0e5b", size = 46156, upload-time = "2025-01-25T09:17:04.831Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/28/e2a36573ccbcf3d57c00626a21fe51989380636e821b341d36ccca0c1c3a/soundfile-0.13.1-py2.py3-none-any.whl", hash = "sha256:a23c717560da2cf4c7b5ae1142514e0fd82d6bbd9dfc93a50423447142f2c445", size = 25751, upload-time = "2025-01-25T09:16:44.235Z" },
{ url = "https://files.pythonhosted.org/packages/ea/ab/73e97a5b3cc46bba7ff8650a1504348fa1863a6f9d57d7001c6b67c5f20e/soundfile-0.13.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:82dc664d19831933fe59adad199bf3945ad06d84bc111a5b4c0d3089a5b9ec33", size = 1142250, upload-time = "2025-01-25T09:16:47.583Z" },
{ url = "https://files.pythonhosted.org/packages/a0/e5/58fd1a8d7b26fc113af244f966ee3aecf03cb9293cb935daaddc1e455e18/soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:743f12c12c4054921e15736c6be09ac26b3b3d603aef6fd69f9dde68748f2593", size = 1101406, upload-time = "2025-01-25T09:16:49.662Z" },
{ url = "https://files.pythonhosted.org/packages/58/ae/c0e4a53d77cf6e9a04179535766b3321b0b9ced5f70522e4caf9329f0046/soundfile-0.13.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9c9e855f5a4d06ce4213f31918653ab7de0c5a8d8107cd2427e44b42df547deb", size = 1235729, upload-time = "2025-01-25T09:16:53.018Z" },
{ url = "https://files.pythonhosted.org/packages/57/5e/70bdd9579b35003a489fc850b5047beeda26328053ebadc1fb60f320f7db/soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:03267c4e493315294834a0870f31dbb3b28a95561b80b134f0bd3cf2d5f0e618", size = 1313646, upload-time = "2025-01-25T09:16:54.872Z" },
{ url = "https://files.pythonhosted.org/packages/fe/df/8c11dc4dfceda14e3003bb81a0d0edcaaf0796dd7b4f826ea3e532146bba/soundfile-0.13.1-py2.py3-none-win32.whl", hash = "sha256:c734564fab7c5ddf8e9be5bf70bab68042cd17e9c214c06e365e20d64f9a69d5", size = 899881, upload-time = "2025-01-25T09:16:56.663Z" },
{ url = "https://files.pythonhosted.org/packages/14/e9/6b761de83277f2f02ded7e7ea6f07828ec78e4b229b80e4ca55dd205b9dc/soundfile-0.13.1-py2.py3-none-win_amd64.whl", hash = "sha256:1e70a05a0626524a69e9f0f4dd2ec174b4e9567f4d8b6c11d38b5c289be36ee9", size = 1019162, upload-time = "2025-01-25T09:16:59.573Z" },
]
[[package]]
name = "soupsieve"
version = "2.7"