feat: add a backend parser for WebVTT files (#2288)

* feat: add a backend parser for WebVTT files

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* docs: update README with VTT support

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* docs: add description to supported formats

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* chore: upgrade docling-core to unescape WebVTT in markdown

Pin the new release of docling-core 2.48.2.
Do not escape HTML reserved characters when exporting WebVTT documents to markdown.

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* test: add missing copyright notice

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

---------

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
This commit is contained in:
Cesar Berrospi Ramis
2025-09-22 15:24:34 +02:00
committed by GitHub
parent b5628f1227
commit 46efaaefee
23 changed files with 3969 additions and 34 deletions

View File

@@ -0,0 +1,572 @@
import logging
import re
from io import BytesIO
from pathlib import Path
from typing import Annotated, ClassVar, Literal, Optional, Union, cast
from docling_core.types.doc import (
ContentLayer,
DocItemLabel,
DoclingDocument,
DocumentOrigin,
Formatting,
GroupLabel,
NodeItem,
)
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic.types import StringConstraints
from typing_extensions import Self, override
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
class _WebVTTTimestamp(BaseModel):
    """Model representing a WebVTT timestamp.

    A WebVTT timestamp is always interpreted relative to the current playback position
    of the media data that the WebVTT file is to be synchronized with.

    Accepted forms are ``mm:ss.ttt`` and ``hh:mm:ss.ttt`` (two or more hour digits),
    written with ASCII digits only.
    """

    model_config = ConfigDict(regex_engine="python-re")

    raw: Annotated[
        str,
        Field(
            description="A representation of the WebVTT Timestamp as a single string"
        ),
    ]

    # Explicit [0-9] classes: "\d" would also accept non-ASCII decimal digits.
    # Minutes and seconds are limited to 00-59 by the pattern itself.
    _pattern: ClassVar[re.Pattern] = re.compile(
        r"(?:([0-9]{2,}):)?([0-5][0-9]):([0-5][0-9])\.([0-9]{3})"
    )

    _hours: int
    _minutes: int
    _seconds: int
    _millis: int

    @model_validator(mode="after")
    def validate_raw(self) -> Self:
        """Parse ``raw`` into its hour, minute, second and millisecond parts.

        Raises:
            ValueError: If ``raw`` is not a well-formed WebVTT timestamp.
        """
        # fullmatch (instead of match + "$") also rejects a trailing newline.
        m = self._pattern.fullmatch(self.raw)
        if not m:
            raise ValueError(f"Invalid WebVTT timestamp format: {self.raw}")
        self._hours = int(m.group(1)) if m.group(1) else 0
        self._minutes = int(m.group(2))
        self._seconds = int(m.group(3))
        self._millis = int(m.group(4))
        return self

    @property
    def seconds(self) -> float:
        """A representation of the WebVTT Timestamp in seconds"""
        return (
            self._hours * 3600
            + self._minutes * 60
            + self._seconds
            + self._millis / 1000.0
        )

    @override
    def __str__(self) -> str:
        return self.raw
# A cue identifier is a non-empty, single-line string (no LF or CR) that does
# not contain the cue timings separator "-->".
_WebVTTCueIdentifier = Annotated[
    str, StringConstraints(strict=True, pattern=r"^(?!.*-->)[^\n\r]+$")
]
class _WebVTTCueTimings(BaseModel):
    """Model representing WebVTT cue timings."""

    start: Annotated[
        _WebVTTTimestamp, Field(description="Start time offset of the cue")
    ]
    end: Annotated[_WebVTTTimestamp, Field(description="End time offset of the cue")]

    @model_validator(mode="after")
    def check_order(self) -> Self:
        """Ensure the cue ends strictly after it starts.

        Raises:
            ValueError: If the end offset is not greater than the start offset.
        """
        begin, finish = self.start, self.end
        if not (begin and finish):
            return self
        if finish.seconds > begin.seconds:
            return self
        raise ValueError("End timestamp must be greater than start timestamp")

    @override
    def __str__(self):
        return " --> ".join((str(self.start), str(self.end)))
class _WebVTTCueTextSpan(BaseModel):
    """Model representing a WebVTT cue text span."""

    text: str
    span_type: Literal["text"] = "text"

    @field_validator("text", mode="after")
    @classmethod
    def validate_text(cls, value: str) -> str:
        """Reject empty text and text holding a line break, ``&`` or ``<``.

        Raises:
            ValueError: If the text is empty or contains a forbidden character.
        """
        for forbidden in ("\n", "\r", "&", "<"):
            if forbidden in value:
                raise ValueError("Cue text span contains invalid characters")
        if not value:
            raise ValueError("Cue text span cannot be empty")
        return value

    @override
    def __str__(self):
        return self.text
class _WebVTTCueVoiceSpan(BaseModel):
    """Model representing a WebVTT cue voice span."""

    annotation: Annotated[
        str,
        Field(
            description=(
                "Cue span start tag annotation text representing the name of the voice"
            )
        ),
    ]
    classes: Annotated[
        list[str],
        Field(description="List of classes representing the cue span's significance"),
    ] = []
    components: Annotated[
        list["_WebVTTCueComponent"],
        Field(description="The components representing the cue internal text"),
    ] = []
    span_type: Literal["v"] = "v"

    @field_validator("annotation", mode="after")
    @classmethod
    def validate_annotation(cls, value: str) -> str:
        """Check that the voice name is non-empty and free of forbidden characters.

        Raises:
            ValueError: If the annotation is empty or contains a line break,
                ``&`` or ``>``.
        """
        if any(ch in value for ch in {"\n", "\r", "&", ">"}):
            raise ValueError(
                "Cue span start tag annotation contains invalid characters"
            )
        if not value:
            # The message names the annotation; it used to mention the text span.
            raise ValueError("Cue span start tag annotation cannot be empty")
        return value

    @field_validator("classes", mode="after")
    @classmethod
    def validate_classes(cls, value: list[str]) -> list[str]:
        """Check that every class name is non-empty and free of forbidden characters.

        Raises:
            ValueError: If a class is empty or contains whitespace, ``&``, ``<``,
                ``>`` or ``.``.
        """
        for item in value:
            if any(ch in item for ch in {"\t", "\n", "\r", " ", "&", "<", ">", "."}):
                raise ValueError(
                    "A cue span start tag class contains invalid characters"
                )
            if not item:
                raise ValueError("Cue span start tag classes cannot be empty")
        return value

    @override
    def __str__(self):
        tag = f"v.{'.'.join(self.classes)}" if self.classes else "v"
        inner = "".join(str(span) for span in self.components)
        return f"<{tag} {self.annotation}>{inner}</v>"
class _WebVTTCueClassSpan(BaseModel):
    """Model representing a WebVTT cue class span."""

    span_type: Literal["c"] = "c"
    components: list["_WebVTTCueComponent"]

    @override
    def __str__(self):
        # Serialize the children in order and wrap them in the class tag.
        body = "".join(map(str, self.components))
        return "<c>" + body + "</c>"
class _WebVTTCueItalicSpan(BaseModel):
    """Model representing a WebVTT cue italic span."""

    span_type: Literal["i"] = "i"
    components: list["_WebVTTCueComponent"]

    @override
    def __str__(self):
        # Serialize the children in order and wrap them in the italic tag.
        body = "".join(map(str, self.components))
        return "<i>" + body + "</i>"
class _WebVTTCueBoldSpan(BaseModel):
    """Model representing a WebVTT cue bold span."""

    span_type: Literal["b"] = "b"
    components: list["_WebVTTCueComponent"]

    @override
    def __str__(self):
        # Serialize the children in order and wrap them in the bold tag.
        body = "".join(map(str, self.components))
        return "<b>" + body + "</b>"
class _WebVTTCueUnderlineSpan(BaseModel):
    """Model representing a WebVTT cue underline span."""

    span_type: Literal["u"] = "u"
    components: list["_WebVTTCueComponent"]

    @override
    def __str__(self):
        # Serialize the children in order and wrap them in the underline tag.
        body = "".join(map(str, self.components))
        return "<u>" + body + "</u>"
# Tagged union of every supported cue payload component. The "span_type"
# literal declared on each span model is the discriminator used to select
# the concrete model during validation.
_WebVTTCueComponent = Annotated[
    Union[
        _WebVTTCueTextSpan,
        _WebVTTCueClassSpan,
        _WebVTTCueItalicSpan,
        _WebVTTCueBoldSpan,
        _WebVTTCueUnderlineSpan,
        _WebVTTCueVoiceSpan,
    ],
    Field(discriminator="span_type", description="The WebVTT cue component"),
]
class _WebVTTCueBlock(BaseModel):
    """Model representing a WebVTT cue block.

    The optional WebVTT cue settings list is not supported.
    The cue payload is limited to the following spans: text, class, italic, bold,
    underline, and voice.
    """

    model_config = ConfigDict(regex_engine="python-re")

    identifier: Optional[_WebVTTCueIdentifier] = Field(
        None, description="The WebVTT cue identifier"
    )
    timings: Annotated[_WebVTTCueTimings, Field(description="The WebVTT cue timings")]
    payload: Annotated[list[_WebVTTCueComponent], Field(description="The cue payload")]

    _pattern_block: ClassVar[re.Pattern] = re.compile(
        r"<(/?)(i|b|c|u|v(?:\.[^\t\n\r &<>.]+)*)(?:\s+([^>]*))?>"
    )
    _pattern_voice_tag: ClassVar[re.Pattern] = re.compile(
        r"^<v(?P<class>\.[^\t\n\r &<>]+)?"  # zero or more classes
        r"[ \t]+(?P<annotation>[^\n\r&>]+)>"  # required space and annotation
    )

    @field_validator("payload", mode="after")
    @classmethod
    def validate_payload(cls, payload):
        """Reject payloads whose serialized form contains the ``-->`` separator."""
        for voice in payload:
            if "-->" in str(voice):
                raise ValueError("Cue payload must not contain '-->'")
        return payload

    @classmethod
    def _close_span(
        cls, open_tag: str, children: list[_WebVTTCueComponent]
    ) -> list[_WebVTTCueComponent]:
        """Build the component(s) produced by closing the span opened by ``open_tag``.

        Returns a one-element list holding the new span, or ``children`` unchanged
        when ``open_tag`` is a voice tag without a usable annotation, so that the
        enclosed text is kept instead of being discarded.
        """
        simple_spans = {
            "<i>": _WebVTTCueItalicSpan,
            "<b>": _WebVTTCueBoldSpan,
            "<u>": _WebVTTCueUnderlineSpan,
            "<c>": _WebVTTCueClassSpan,
        }
        span_cls = simple_spans.get(open_tag)
        if span_cls is not None:
            return [span_cls(components=children)]

        voice_match = cls._pattern_voice_tag.match(open_tag)
        if voice_match is None:
            return children
        class_string = voice_match.group("class")
        classes = [c for c in class_string.split(".") if c] if class_string else []
        return [
            _WebVTTCueVoiceSpan(
                annotation=voice_match.group("annotation").strip(),
                classes=classes,
                components=children,
            )
        ]

    @classmethod
    def parse(cls, raw: str) -> "_WebVTTCueBlock":
        """Parse the raw text of a single cue block.

        The block consists of an optional identifier line, a timings line
        (``start --> end``, anything after the end timestamp is ignored) and the
        payload lines, which are joined with single spaces.

        End tags that do not match the innermost open span are ignored, and spans
        still open at the end of the payload are closed implicitly, so malformed
        markup neither crashes the parser nor drops the enclosed text.

        Raises:
            ValueError: If the block is empty, has no valid timings line, or any
                component fails validation.
        """
        lines = raw.strip().splitlines()
        if not lines:
            raise ValueError("Cue block must have at least one line")

        identifier: Optional[_WebVTTCueIdentifier] = None
        timing_line = lines[0]
        if "-->" not in timing_line and len(lines) > 1:
            identifier = timing_line
            timing_line = lines[1]
            cue_lines = lines[2:]
        else:
            cue_lines = lines[1:]

        if "-->" not in timing_line:
            raise ValueError("Cue block must contain WebVTT cue timings")
        if timing_line.count("-->") != 1:
            raise ValueError("Cue timings must contain exactly one '-->' separator")
        start, _, end = timing_line.partition("-->")
        # ignore the cue settings list
        end = re.split(r"[ \t]", end.strip(), maxsplit=1)[0]
        timings: _WebVTTCueTimings = _WebVTTCueTimings(
            start=_WebVTTTimestamp(raw=start.strip()),
            end=_WebVTTTimestamp(raw=end),
        )

        cue_text = " ".join(cue_lines).strip()

        root: list[_WebVTTCueComponent] = []
        # One frame per open span: the raw start tag and the children seen so far.
        open_spans: list[tuple[str, list[_WebVTTCueComponent]]] = []

        def current() -> list[_WebVTTCueComponent]:
            return open_spans[-1][1] if open_spans else root

        pos = 0
        for match in cls._pattern_block.finditer(cue_text):
            if match.start() > pos:
                current().append(_WebVTTCueTextSpan(text=cue_text[pos : match.start()]))
            pos = match.end()
            tag = match.group(0)

            if tag in ("<i>", "<b>", "<u>", "<c>") or tag.startswith("<v"):
                open_spans.append((tag, []))
            elif tag in ("</i>", "</b>", "</u>", "</c>") or tag.startswith("</v"):
                kind = match.group(2)[0]
                # Only close when the end tag matches the innermost open span;
                # stray or mismatched end tags are ignored.
                if open_spans and open_spans[-1][0][1] == kind:
                    open_tag, children = open_spans.pop()
                    current().extend(cls._close_span(open_tag, children))
            # Any other tag (e.g. "<i annotation>") is skipped.

        if pos < len(cue_text):
            current().append(_WebVTTCueTextSpan(text=cue_text[pos:]))

        # The end tag may be omitted: close whatever is still open, innermost first.
        while open_spans:
            open_tag, children = open_spans.pop()
            current().extend(cls._close_span(open_tag, children))

        return cls(
            identifier=identifier,
            timings=timings,
            payload=root,
        )

    @override
    def __str__(self):
        parts = []
        if self.identifier:
            parts.append(f"{self.identifier}\n")
        timings_line = str(self.timings)
        parts.append(timings_line + "\n")
        for idx, span in enumerate(self.payload):
            if idx == 0 and len(self.payload) == 1 and span.span_type == "v":
                # the end tag may be omitted for brevity
                parts.append(str(span).removesuffix("</v>"))
            else:
                parts.append(str(span))
        return "".join(parts)
class _WebVTTFile(BaseModel):
    """A model representing a WebVTT file."""

    cue_blocks: list[_WebVTTCueBlock]

    # First line of a block that is not a cue: a comment ("NOTE" alone or followed
    # by a space/tab and text) or a "STYLE"/"REGION" header line. Anchoring on the
    # whole first line keeps cue text such as "NOTES ..." or "REGIONAL ..." intact.
    _pattern_skipped_block: ClassVar[re.Pattern] = re.compile(
        r"(?:NOTE(?:[ \t][^\n]*)?|(?:STYLE|REGION)[ \t]*)(?:\n|\Z)"
    )

    @staticmethod
    def verify_signature(content: str) -> bool:
        """Tell whether ``content`` starts with a valid WebVTT signature.

        The signature is an optional byte order mark, the string ``WEBVTT`` and
        then either the end of the content or a space, tab or line terminator
        (LF or CR, so CRLF files are accepted without prior normalization).
        """
        if not content:
            return False
        content = content.removeprefix("\ufeff")
        if not content.startswith("WEBVTT"):
            return False
        return len(content) == 6 or content[6] in (" ", "\t", "\n", "\r")

    @classmethod
    def parse(cls, raw: str) -> "_WebVTTFile":
        """Parse the text of a WebVTT file into its cue blocks.

        NOTE, STYLE and REGION blocks are skipped. Cue blocks that fail to parse
        are logged as warnings and left out of the result.

        Raises:
            ValueError: If the content does not start with a WebVTT signature.
        """
        # Normalize newlines to LF
        raw = raw.replace("\r\n", "\n").replace("\r", "\n")

        # Check WebVTT signature
        if not cls.verify_signature(raw):
            raise ValueError("Invalid WebVTT file signature")

        # Strip "WEBVTT" header line (and the BOM, if any, that precedes it)
        _, _, body = raw.partition("\n")

        cues: list[_WebVTTCueBlock] = []
        # Split into blocks first, then filter: deciding per block avoids removing
        # lines inside a cue that merely start with NOTE/STYLE/REGION.
        for block in re.split(r"\n\s*\n", body.strip()):
            if not block.strip():
                continue  # header-only file or stray whitespace
            if cls._pattern_skipped_block.match(block):
                continue
            try:
                cues.append(_WebVTTCueBlock.parse(block))
            except ValueError as e:
                _log.warning("Failed to parse cue block:\n%s\n%s", block, e)

        return cls(cue_blocks=cues)

    def __iter__(self):
        return iter(self.cue_blocks)

    def __getitem__(self, idx):
        return self.cue_blocks[idx]

    def __len__(self):
        return len(self.cue_blocks)
class WebVTTDocumentBackend(DeclarativeDocumentBackend):
    """Declarative backend for WebVTT (.vtt) files.

    This parser reads the content of a WebVTT file and converts
    it to a DoclingDocument, following the W3C specs on https://www.w3.org/TR/webvtt1

    Each cue block becomes a group in the document body, in file order. The group
    holds a text item for the optional identifier, one for the timings and one or
    more for the payload; voice spans get their own inline group.
    """

    @override
    def __init__(self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]):
        """Load the whole file content as text.

        Raises:
            RuntimeError: If the content cannot be read or decoded as UTF-8.
        """
        super().__init__(in_doc, path_or_stream)

        self.content: str = ""
        try:
            # "utf-8-sig" drops an optional leading byte order mark.
            if isinstance(self.path_or_stream, BytesIO):
                text = self.path_or_stream.getvalue().decode("utf-8-sig")
                # Text-mode open() below translates newlines; do the same here so
                # both input kinds yield the same content.
                self.content = text.replace("\r\n", "\n").replace("\r", "\n")
            elif isinstance(self.path_or_stream, Path):
                with open(self.path_or_stream, encoding="utf-8-sig") as f:
                    self.content = f.read()
        except Exception as e:
            raise RuntimeError(
                "Could not initialize the WebVTT backend for file with hash "
                f"{self.document_hash}."
            ) from e

    @override
    def is_valid(self) -> bool:
        return _WebVTTFile.verify_signature(self.content)

    @classmethod
    @override
    def supports_pagination(cls) -> bool:
        return False

    @override
    def unload(self):
        if isinstance(self.path_or_stream, BytesIO):
            self.path_or_stream.close()
        self.path_or_stream = None

    @classmethod
    @override
    def supported_formats(cls) -> set[InputFormat]:
        return {InputFormat.VTT}

    @staticmethod
    def _flatten_text(item: _WebVTTCueComponent) -> str:
        """Concatenate the text of a cue component and all of its descendants."""
        if isinstance(item, _WebVTTCueTextSpan):
            return item.text
        return "".join(
            WebVTTDocumentBackend._flatten_text(child) for child in item.components
        )

    @staticmethod
    def _add_text_from_component(
        doc: DoclingDocument, item: _WebVTTCueComponent, parent: Optional[NodeItem]
    ) -> None:
        """Adds a TextItem to a document by extracting text from a cue span component.

        The text of nested spans is flattened into a single item so that it is not
        lost; only the formatting of ``item`` itself is applied. Nothing is added
        when the resulting text is blank.

        TODO: represent the formatting of nested spans
        """
        formatting = Formatting()
        if isinstance(item, _WebVTTCueItalicSpan):
            formatting.italic = True
        elif isinstance(item, _WebVTTCueBoldSpan):
            formatting.bold = True
        elif isinstance(item, _WebVTTCueUnderlineSpan):
            formatting.underline = True

        if text := WebVTTDocumentBackend._flatten_text(item).strip():
            doc.add_text(
                label=DocItemLabel.TEXT,
                text=text,
                parent=parent,
                content_layer=ContentLayer.BODY,
                formatting=formatting,
            )

    @override
    def convert(self) -> DoclingDocument:
        """Convert the loaded WebVTT content into a DoclingDocument.

        Raises:
            RuntimeError: If the content does not start with a WebVTT signature.
        """
        _log.debug("Starting WebVTT conversion...")
        if not self.is_valid():
            raise RuntimeError("Invalid WebVTT document.")

        origin = DocumentOrigin(
            filename=self.file.name or "file",
            mimetype="text/vtt",
            binary_hash=self.document_hash,
        )
        doc = DoclingDocument(name=self.file.stem or "file", origin=origin)

        vtt: _WebVTTFile = _WebVTTFile.parse(self.content)
        for block in vtt.cue_blocks:
            block_group = doc.add_group(
                label=GroupLabel.SECTION,
                name="WebVTT cue block",
                parent=None,
                content_layer=ContentLayer.BODY,
            )
            if block.identifier:
                doc.add_text(
                    label=DocItemLabel.TEXT,
                    text=str(block.identifier),
                    parent=block_group,
                    content_layer=ContentLayer.BODY,
                )
            doc.add_text(
                label=DocItemLabel.TEXT,
                text=str(block.timings),
                parent=block_group,
                content_layer=ContentLayer.BODY,
            )
            for cue_span in block.payload:
                if isinstance(cue_span, _WebVTTCueVoiceSpan):
                    voice_group = doc.add_group(
                        label=GroupLabel.INLINE,
                        name="WebVTT cue voice span",
                        parent=block_group,
                        content_layer=ContentLayer.BODY,
                    )
                    # Speaker label: "<name>: " or "<name> (<class>, ...): "
                    voice = cue_span.annotation
                    if classes := cue_span.classes:
                        voice += f" ({', '.join(classes)})"
                    voice += ": "
                    doc.add_text(
                        label=DocItemLabel.TEXT,
                        text=voice,
                        parent=voice_group,
                        content_layer=ContentLayer.BODY,
                    )
                    for item in cue_span.components:
                        WebVTTDocumentBackend._add_text_from_component(
                            doc, item, voice_group
                        )
                else:
                    WebVTTDocumentBackend._add_text_from_component(
                        doc, cue_span, block_group
                    )

        return doc

View File

@@ -1,7 +1,6 @@
import math
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Optional, Type, Union
import numpy as np
from docling_core.types.doc import (
@@ -14,9 +13,7 @@ from docling_core.types.doc import (
)
from docling_core.types.doc.base import PydanticSerCtxKey, round_pydantic_float
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
from docling_core.types.io import (
DocumentStream,
)
from docling_core.types.io import DocumentStream
# DO NOT REMOVE; explicitly exposed from this location
from PIL.Image import Image
@@ -71,6 +68,7 @@ class InputFormat(str, Enum):
METS_GBS = "mets_gbs"
JSON_DOCLING = "json_docling"
AUDIO = "audio"
VTT = "vtt"
class OutputFormat(str, Enum):
@@ -82,7 +80,7 @@ class OutputFormat(str, Enum):
DOCTAGS = "doctags"
FormatToExtensions: Dict[InputFormat, List[str]] = {
FormatToExtensions: dict[InputFormat, list[str]] = {
InputFormat.DOCX: ["docx", "dotx", "docm", "dotm"],
InputFormat.PPTX: ["pptx", "potx", "ppsx", "pptm", "potm", "ppsm"],
InputFormat.PDF: ["pdf"],
@@ -97,9 +95,10 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
InputFormat.METS_GBS: ["tar.gz"],
InputFormat.JSON_DOCLING: ["json"],
InputFormat.AUDIO: ["wav", "mp3"],
InputFormat.VTT: ["vtt"],
}
FormatToMimeType: Dict[InputFormat, List[str]] = {
FormatToMimeType: dict[InputFormat, list[str]] = {
InputFormat.DOCX: [
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.openxmlformats-officedocument.wordprocessingml.template",
@@ -130,6 +129,7 @@ FormatToMimeType: Dict[InputFormat, List[str]] = {
InputFormat.METS_GBS: ["application/mets+xml"],
InputFormat.JSON_DOCLING: ["application/json"],
InputFormat.AUDIO: ["audio/x-wav", "audio/mpeg", "audio/wav", "audio/mp3"],
InputFormat.VTT: ["text/vtt"],
}
MimeTypeToFormat: dict[str, list[InputFormat]] = {
@@ -162,8 +162,8 @@ class Cluster(BaseModel):
label: DocItemLabel
bbox: BoundingBox
confidence: float = 1.0
cells: List[TextCell] = []
children: List["Cluster"] = [] # Add child cluster support
cells: list[TextCell] = []
children: list["Cluster"] = [] # Add child cluster support
@field_serializer("confidence")
def _serialize(self, value: float, info: FieldSerializationInfo) -> float:
@@ -179,7 +179,7 @@ class BasePageElement(BaseModel):
class LayoutPrediction(BaseModel):
clusters: List[Cluster] = []
clusters: list[Cluster] = []
class VlmPredictionToken(BaseModel):
@@ -201,14 +201,14 @@ class ContainerElement(
class Table(BasePageElement):
otsl_seq: List[str]
otsl_seq: list[str]
num_rows: int = 0
num_cols: int = 0
table_cells: List[TableCell]
table_cells: list[TableCell]
class TableStructurePrediction(BaseModel):
table_map: Dict[int, Table] = {}
table_map: dict[int, Table] = {}
class TextElement(BasePageElement):
@@ -216,7 +216,7 @@ class TextElement(BasePageElement):
class FigureElement(BasePageElement):
annotations: List[PictureDataType] = []
annotations: list[PictureDataType] = []
provenance: Optional[str] = None
predicted_class: Optional[str] = None
confidence: Optional[float] = None
@@ -234,12 +234,12 @@ class FigureElement(BasePageElement):
class FigureClassificationPrediction(BaseModel):
figure_count: int = 0
figure_map: Dict[int, FigureElement] = {}
figure_map: dict[int, FigureElement] = {}
class EquationPrediction(BaseModel):
equation_count: int = 0
equation_map: Dict[int, TextElement] = {}
equation_map: dict[int, TextElement] = {}
class PagePredictions(BaseModel):
@@ -254,9 +254,9 @@ PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
class AssembledUnit(BaseModel):
elements: List[PageElement] = []
body: List[PageElement] = []
headers: List[PageElement] = []
elements: list[PageElement] = []
body: list[PageElement] = []
headers: list[PageElement] = []
class ItemAndImageEnrichmentElement(BaseModel):
@@ -280,12 +280,12 @@ class Page(BaseModel):
None # Internal PDF backend. By default it is cleared during assembling.
)
_default_image_scale: float = 1.0 # Default image scale for external usage.
_image_cache: Dict[
_image_cache: dict[
float, Image
] = {} # Cache of images in different scales. By default it is cleared during assembling.
@property
def cells(self) -> List[TextCell]:
def cells(self) -> list[TextCell]:
"""Return text cells as a read-only view of parsed_page.textline_cells."""
if self.parsed_page is not None:
return self.parsed_page.textline_cells
@@ -354,7 +354,7 @@ class OpenAiApiResponse(BaseModel):
id: str
model: Optional[str] = None # returned by openai
choices: List[OpenAiResponseChoice]
choices: list[OpenAiResponseChoice]
created: int
usage: OpenAiResponseUsage
@@ -430,7 +430,7 @@ class PageConfidenceScores(BaseModel):
class ConfidenceReport(PageConfidenceScores):
pages: Dict[int, PageConfidenceScores] = Field(
pages: dict[int, PageConfidenceScores] = Field(
default_factory=lambda: defaultdict(PageConfidenceScores)
)

View File

@@ -394,6 +394,8 @@ class _DocumentConversionInput(BaseModel):
mime = FormatToMimeType[InputFormat.PPTX][0]
elif ext in FormatToExtensions[InputFormat.XLSX]:
mime = FormatToMimeType[InputFormat.XLSX][0]
elif ext in FormatToExtensions[InputFormat.VTT]:
mime = FormatToMimeType[InputFormat.VTT][0]
return mime

View File

@@ -25,6 +25,7 @@ from docling.backend.msexcel_backend import MsExcelDocumentBackend
from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
from docling.backend.msword_backend import MsWordDocumentBackend
from docling.backend.noop_backend import NoOpBackend
from docling.backend.webvtt_backend import WebVTTDocumentBackend
from docling.backend.xml.jats_backend import JatsDocumentBackend
from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend
from docling.datamodel.base_models import (
@@ -170,6 +171,9 @@ def _get_default_option(format: InputFormat) -> FormatOption:
pipeline_cls=SimplePipeline, backend=DoclingJSONBackend
),
InputFormat.AUDIO: FormatOption(pipeline_cls=AsrPipeline, backend=NoOpBackend),
InputFormat.VTT: FormatOption(
pipeline_cls=SimplePipeline, backend=WebVTTDocumentBackend
),
}
if (options := format_to_default_options.get(format)) is not None:
return options