work in progress: slowly adding ASR pipeline and its derivatives

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-05-12 07:33:38 +02:00
parent 776e7ecf9a
commit 32ad65cb9f
9 changed files with 241 additions and 6 deletions

View File

@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from io import BytesIO
from pathlib import Path
from typing import Optional, Set, Union
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
class WavDocumentBackend(AbstractDocumentBackend):
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
def is_valid(self) -> bool:
return True
@classmethod
def supports_pagination(cls) -> bool:
return False
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
self.path_or_stream = None
@classmethod
def supported_formats(cls) -> set[InputFormat]:
return {InputFormat.WAV}

View File

@ -570,7 +570,18 @@ def convert( # noqa: C901
pdf_format_option = PdfFormatOption(
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
)
elif pipeline == PdfPipeline.ASR:
pipeline_options = AsrPipelineOptions()
pipeline_options.asr_options = asr_nemo_conversion_options
asr_format_option = AsrFormatOption(
pipeline_cls=AsrPipeline, pipeline_options=pipeline_options
)
else:
_log.error(f"Did not find the correct pipeline: {pipeline}")
if artifacts_path is not None:
pipeline_options.artifacts_path = artifacts_path

View File

@ -34,6 +34,7 @@ class ConversionStatus(str, Enum):
class InputFormat(str, Enum):
"""A document format supported by document backend parsers."""
# Documents
DOCX = "docx"
PPTX = "pptx"
HTML = "html"
@ -47,6 +48,8 @@ class InputFormat(str, Enum):
XML_JATS = "xml_jats"
JSON_DOCLING = "json_docling"
# Audio
WAV = "wav"
class OutputFormat(str, Enum):
MARKDOWN = "md"
@ -70,6 +73,8 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
InputFormat.XLSX: ["xlsx"],
InputFormat.XML_USPTO: ["xml", "txt"],
InputFormat.JSON_DOCLING: ["json"],
# Audio
InputFormat.WAV: ["wav"],
}
FormatToMimeType: Dict[InputFormat, List[str]] = {
@ -100,6 +105,9 @@ FormatToMimeType: Dict[InputFormat, List[str]] = {
],
InputFormat.XML_USPTO: ["application/xml", "text/plain"],
InputFormat.JSON_DOCLING: ["application/json"],
# Audio
InputFormat.WAV: ["audio/wav", "audio/x-wav"],
}
MimeTypeToFormat: dict[str, list[InputFormat]] = {
@ -157,6 +165,9 @@ class LayoutPrediction(BaseModel):
class VlmPrediction(BaseModel):
text: str = ""
class AsrPrediction(BaseModel):
text: str = ""
class ContainerElement(
BasePageElement

View File

@ -278,6 +278,8 @@ class _DocumentConversionInput(BaseModel):
if isinstance(obj, Path):
mime = filetype.guess_mime(str(obj))
_log.debug(f"guessed mime: {mime}")
if mime is None:
ext = obj.suffix[1:]
mime = _DocumentConversionInput._mime_from_extension(ext)
@ -290,8 +292,8 @@ class _DocumentConversionInput(BaseModel):
elif obj.suffixes[-1].lower() == ".docx":
mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
elif obj.suffixes[-1].lower() == ".pptx":
mime = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
mime = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
elif isinstance(obj, DocumentStream):
content = obj.stream.read(8192)
obj.stream.seek(0)
@ -311,10 +313,11 @@ class _DocumentConversionInput(BaseModel):
mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
elif objname.endswith(".pptx"):
mime = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
mime = mime or _DocumentConversionInput._detect_html_xhtml(content)
mime = mime or _DocumentConversionInput._detect_csv(content)
mime = mime or "text/plain"
formats = MimeTypeToFormat.get(mime, [])
if formats:
if len(formats) == 1 and mime not in ("text/plain",):
@ -363,6 +366,8 @@ class _DocumentConversionInput(BaseModel):
@staticmethod
def _mime_from_extension(ext):
print("ext: ", ext)
mime = None
if ext in FormatToExtensions[InputFormat.ASCIIDOC]:
mime = FormatToMimeType[InputFormat.ASCIIDOC][0]
@ -376,6 +381,8 @@ class _DocumentConversionInput(BaseModel):
mime = FormatToMimeType[InputFormat.JSON_DOCLING][0]
elif ext in FormatToExtensions[InputFormat.PDF]:
mime = FormatToMimeType[InputFormat.PDF][0]
elif ext in FormatToExtensions[InputFormat.WAV]:
mime = FormatToMimeType[InputFormat.WAV][0]
return mime
@staticmethod

View File

@ -257,6 +257,9 @@ class BaseVlmOptions(BaseModel):
kind: str
prompt: str
class BaseAsrOptions(BaseModel):
kind: str
prompt: str
class ResponseFormat(str, Enum):
DOCTAGS = "doctags"
@ -268,6 +271,8 @@ class InferenceFramework(str, Enum):
TRANSFORMERS = "transformers"
OPENAI = "openai"
# Audio
ASR_NEMO = "asr_nemo"
class HuggingFaceVlmOptions(BaseVlmOptions):
kind: Literal["hf_model_options"] = "hf_model_options"
@ -284,6 +289,20 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
class HuggingFaceAsrOptions(BaseAsrOptions):
kind: Literal["hf_model_options"] = "hf_model_options"
repo_id: str
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
response_format: ResponseFormat
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
class ApiVlmOptions(BaseVlmOptions):
kind: Literal["api_model_options"] = "api_model_options"
@ -330,6 +349,13 @@ granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
response_format=ResponseFormat.MARKDOWN,
)
asr_nemo_conversion_options = HuggingFaceAsrOptions(
repo_id="nvidia/parakeet-tdt-0.6b-v2",
prompt="Convert this page to docling.",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.ASR_NEMO,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
@ -389,7 +415,11 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
smoldocling_vlm_conversion_options
)
class AsrPipelineOptions(PaginatedPipelineOptions):
asr_options: HuggingFaceAsrOptions = asr_nemo_conversion_options
class PdfPipelineOptions(PaginatedPipelineOptions):
"""Options for the PDF pipeline."""

View File

@ -21,6 +21,7 @@ from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
from docling.backend.msword_backend import MsWordDocumentBackend
from docling.backend.xml.jats_backend import JatsDocumentBackend
from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend
from docling.backend.wav_backend import WavDocumentBackend
from docling.datamodel.base_models import (
ConversionStatus,
DoclingComponentType,
@ -33,7 +34,7 @@ from docling.datamodel.document import (
InputDocument,
_DocumentConversionInput,
)
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.pipeline_options import PipelineOptions, AsrPipelineOptions
from docling.datamodel.settings import (
DEFAULT_PAGE_RANGE,
DocumentLimits,
@ -44,6 +45,7 @@ from docling.exceptions import ConversionError
from docling.pipeline.base_pipeline import BasePipeline
from docling.pipeline.simple_pipeline import SimplePipeline
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from docling.pipeline.asr_pipeline import AsrPipeline
from docling.utils.utils import chunkify
_log = logging.getLogger(__name__)
@ -117,7 +119,9 @@ class PdfFormatOption(FormatOption):
pipeline_cls: Type = StandardPdfPipeline
backend: Type[AbstractDocumentBackend] = DoclingParseV4DocumentBackend
class AsrFormatOption(FormatOption):
pipeline_cls: Type = AsrPipeline
backend: Type[AbstractDocumentBackend] = WavDocumentBackend
def _get_default_option(format: InputFormat) -> FormatOption:
format_to_default_options = {
InputFormat.CSV: FormatOption(
@ -156,6 +160,9 @@ def _get_default_option(format: InputFormat) -> FormatOption:
InputFormat.JSON_DOCLING: FormatOption(
pipeline_cls=SimplePipeline, backend=DoclingJSONBackend
),
InputFormat.WAV: FormatOption(
pipeline_cls=AsrPipeline, backend=WavDocumentBackend
),
}
if (options := format_to_default_options.get(format)) is not None:
return options
@ -292,7 +299,10 @@ class DocumentConverter:
"""Retrieve or initialize a pipeline, reusing instances based on class and options."""
fopt = self.format_to_options.get(doc_format)
if fopt is None or fopt.pipeline_options is None:
_log.warning(f"No format option or pipeline options configured for {doc_format}")
return None
pipeline_class = fopt.pipeline_cls
@ -345,6 +355,7 @@ class DocumentConverter:
) -> ConversionResult:
if in_doc.valid:
pipeline = self._get_pipeline(in_doc.format)
print(f"_execute_pipeline: {pipeline}")
if pipeline is not None:
conv_res = pipeline.execute(in_doc, raises_on_error=raises_on_error)
else:
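Tying the registrations in this file together, a hedged end-to-end sketch of how a WAV file would travel through the new ASR path once the WIP is complete; the input file name is a placeholder:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.document_converter import AsrFormatOption, DocumentConverter

converter = DocumentConverter(
    format_options={
        InputFormat.WAV: AsrFormatOption(pipeline_options=AsrPipelineOptions())
    }
)
result = converter.convert("meeting.wav")  # placeholder .wav input
print(result.status)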

View File

View File

@ -0,0 +1,51 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from docling.datamodel.base_models import AsrPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
HuggingFaceAsrOptions,
)
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class AsrNemoModel(BasePageModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
asr_options: HuggingFaceAsrOptions,
):
self.enabled = enabled
self.asr_options = asr_options
if self.enabled:
import nemo.collections.asr as nemo_asr
device = decide_device(accelerator_options.device)
self.device = device
_log.debug(f"Available device for HuggingFace ASR: {device}")
repo_cache_folder = asr_options.repo_id.replace("/", "--")
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.asr_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.model = nemo_asr.models.ASRModel.from_pretrained(self.asr_options.repo_id)
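The excerpt stops before any inference code. Below is a hedged sketch, not part of this diff, of what a small transcription helper on AsrNemoModel could look like; `_transcribe_file` is a hypothetical name, and it relies on the Path and AsrPrediction imports already present in this file. NeMo's ASRModel.transcribe() accepts a list of audio file paths and, depending on the NeMo version, returns plain strings or hypothesis objects:

    def _transcribe_file(self, audio_path: Path) -> AsrPrediction:
        # Run the loaded NeMo model on a single WAV file.
        outputs = self.model.transcribe([str(audio_path)])
        first = outputs[0]
        # Newer NeMo releases return Hypothesis objects with a .text attribute;
        # older ones return the transcript string directly.
        text = getattr(first, "text", None) or str(first)
        return AsrPrediction(text=text)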

View File

@ -0,0 +1,82 @@
import logging
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Union, cast
from docling.backend.abstract_backend import (
AbstractDocumentBackend,
DeclarativeDocumentBackend,
)
from docling.datamodel.base_models import ConversionStatus
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AsrPipelineOptions,
HuggingFaceAsrOptions,
InferenceFramework,
PipelineOptions,
ResponseFormat,
)
from docling.datamodel.settings import settings
from docling.models.hf_asr_models.asr_nemo import AsrNemoModel
from docling.pipeline.base_pipeline import BasePipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
class AsrPipeline(BasePipeline):
def __init__(self, pipeline_options: AsrPipelineOptions):
super().__init__(pipeline_options)
self.keep_backend = True
self.pipeline_options: AsrPipelineOptions
artifacts_path: Optional[Path] = None
if pipeline_options.artifacts_path is not None:
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
elif settings.artifacts_path is not None:
artifacts_path = Path(settings.artifacts_path).expanduser()
if artifacts_path is not None and not artifacts_path.is_dir():
raise RuntimeError(
f"The value of {artifacts_path=} is not valid. "
"When defined, it must point to a folder containing all models required by the pipeline."
)
if isinstance(self.pipeline_options.asr_options, HuggingFaceAsrOptions):
asr_options = cast(HuggingFaceAsrOptions, self.pipeline_options.asr_options)
if asr_options.inference_framework == InferenceFramework.ASR_NEMO:
self.build_pipe = [
AsrNemoModel(
enabled=True,  # must always be enabled for this pipeline to make sense
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
asr_options=asr_options,
),
]
else:
_log.error(f"{asr_options.inference_framework} is not supported")
else:
_log.error(f"ASR is not supported")
def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
# WIP: transcription via the models in self.build_pipe still needs to be wired in here.
return conv_res
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
return conv_res
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
# WIP: report success until real error handling is added.
return ConversionStatus.SUCCESS
def _unload(self, conv_res: ConversionResult):
pass
@classmethod
def get_default_options(cls) -> PipelineOptions:
return AsrPipelineOptions()
@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend):
pass