Files
docling/tests/test_backend_vtt.py
Cesar Berrospi Ramis 46efaaefee feat: add a backend parser for WebVTT files (#2288)
* feat: add a backend parser for WebVTT files

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* docs: update README with VTT support

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* docs: add description to supported formats

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* chore: upgrade docling-core to unescape WebVTT in markdown

Pin the new release of docling-core 2.48.2.
Do not escape HTML reserved characters when exporting WebVTT documents to markdown.

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

* test: add missing copyright notice

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>

---------

Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
2025-09-22 15:24:34 +02:00

233 lines
8.2 KiB
Python

# Assisted by watsonx Code Assistant
from pathlib import Path
import pytest
from docling_core.types.doc import DoclingDocument
from pydantic import ValidationError
from docling.backend.webvtt_backend import (
_WebVTTCueItalicSpan,
_WebVTTCueTextSpan,
_WebVTTCueTimings,
_WebVTTCueVoiceSpan,
_WebVTTFile,
_WebVTTTimestamp,
)
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.document_converter import DocumentConverter
from .test_data_gen_flag import GEN_TEST_DATA
from .verify_utils import verify_document, verify_export
# Module-level alias of the shared GEN_TEST_DATA flag; the end-to-end test
# forwards it as the `generate` argument of verify_export / verify_document.
# NOTE(review): presumably enables (re)writing the ground-truth files instead
# of comparing against them - confirm in verify_utils.
GENERATE = GEN_TEST_DATA
def test_vtt_cue_commponents():
    """Check the WebVTT cue building blocks: timestamps, timings and spans."""
    # Well-formed timestamps and the number of seconds each one stands for.
    seconds_by_timestamp = {
        "00:01:02.345": 1 * 60 + 2.345,
        "12:34:56.789": 12 * 3600 + 34 * 60 + 56.789,
        "02:34.567": 2 * 60 + 34.567,
        "00:00:00.000": 0.0,
    }
    for raw, expected_seconds in seconds_by_timestamp.items():
        assert _WebVTTTimestamp(raw=raw).seconds == expected_seconds

    # Malformed timestamps must be rejected by validation.
    rejected_timestamps = (
        "00:60:02.345",  # minutes > 59
        "00:01:60.345",  # seconds > 59
        "00:01:02.1000",  # milliseconds > 999
        "01:02:03",  # missing milliseconds
        "01:02",  # missing milliseconds
        ":01:02.345",  # extra : for missing hours
        "abc:01:02.345",  # invalid format
    )
    for raw in rejected_timestamps:
        with pytest.raises(ValidationError):
            _WebVTTTimestamp(raw=raw)

    # A timestamp renders back to its raw text.
    assert str(_WebVTTTimestamp(raw="00:01:02.345")) == "00:01:02.345"

    # Cue timings: a valid start/end pair.
    begin = _WebVTTTimestamp(raw="00:10.005")
    finish = _WebVTTTimestamp(raw="00:14.007")
    timings = _WebVTTCueTimings(start=begin, end=finish)
    assert timings.start == begin
    assert timings.end == finish
    assert str(timings) == "00:10.005 --> 00:14.007"

    # Cue timings: the end timestamp lies before the start timestamp.
    begin = _WebVTTTimestamp(raw="00:10.700")
    finish = _WebVTTTimestamp(raw="00:10.500")
    with pytest.raises(ValidationError) as err:
        _WebVTTCueTimings(start=begin, end=finish)
    assert "End timestamp must be greater than start timestamp" in str(err.value)

    # Cue timings: the end timestamp is missing.
    begin = _WebVTTTimestamp(raw="00:10.500")
    with pytest.raises(ValidationError) as err:
        _WebVTTCueTimings(start=begin)
    assert "Field required" in str(err.value)

    # Cue timings: the start timestamp is missing.
    finish = _WebVTTTimestamp(raw="00:10.500")
    with pytest.raises(ValidationError) as err:
        _WebVTTCueTimings(end=finish)
    assert "Field required" in str(err.value)

    # Text span: valid text is stored and rendered unchanged.
    plain_text = "This is a valid cue text span."
    text_span = _WebVTTCueTextSpan(text=plain_text)
    assert text_span.text == plain_text
    assert str(text_span) == plain_text

    # Text span: newline, ampersand, less-than sign and empty text are rejected.
    rejected_texts = (
        "This cue text span\ncontains a newline.",
        "This cue text span contains &.",
        "This cue text span contains <.",
        "",
    )
    for bad_text in rejected_texts:
        with pytest.raises(ValidationError):
            _WebVTTCueTextSpan(text=bad_text)

    # Voice span: annotation validation.
    with pytest.raises(ValidationError):
        _WebVTTCueVoiceSpan(annotation="invalid\nannotation")
    assert _WebVTTCueVoiceSpan(annotation="valid-annotation")

    # Voice span: classes validation.
    speaker = "speaker name"
    with pytest.raises(ValidationError):
        _WebVTTCueVoiceSpan(
            annotation=speaker, classes=["class\nwith\nnewlines", ""]
        )
    assert _WebVTTCueVoiceSpan(annotation=speaker, classes=["class1", "class2"])

    # Voice span: components validation.
    good_components = [_WebVTTCueTextSpan(text="random text")]
    with pytest.raises(ValidationError):
        _WebVTTCueVoiceSpan(annotation=speaker, components=[123, "not a component"])
    assert _WebVTTCueVoiceSpan(annotation=speaker, components=good_components)

    # Voice span: string rendering with classes...
    voice_with_classes = _WebVTTCueVoiceSpan(
        annotation="speaker",
        classes=["loud", "clear"],
        components=[_WebVTTCueTextSpan(text="random text")],
    )
    assert str(voice_with_classes) == "<v.loud.clear speaker>random text</v>"

    # ...and without classes.
    voice_without_classes = _WebVTTCueVoiceSpan(
        annotation="speaker",
        components=[_WebVTTCueTextSpan(text="random text")],
    )
    assert str(voice_without_classes) == "<v speaker>random text</v>"
def test_webvtt_file():
    """Parse the sample WebVTT files and check the resulting cue blocks.

    The sample files are opened through paths relative to the current working
    directory (``./tests/data/webvtt/``).
    """
    # Example 01: cue 11 is a voice span wrapping an italic span wrapping text.
    with open("./tests/data/webvtt/webvtt_example_01.vtt", encoding="utf-8") as f:
        content = f.read()
    vtt = _WebVTTFile.parse(content)
    assert len(vtt) == 13
    block = vtt.cue_blocks[11]
    assert str(block.timings) == "00:32.500 --> 00:33.500"
    assert len(block.payload) == 1
    cue_span = block.payload[0]
    assert isinstance(cue_span, _WebVTTCueVoiceSpan)
    assert cue_span.annotation == "Neil deGrasse Tyson"
    assert not cue_span.classes
    assert len(cue_span.components) == 1
    comp = cue_span.components[0]
    assert isinstance(comp, _WebVTTCueItalicSpan)
    assert len(comp.components) == 1
    comp2 = comp.components[0]
    assert isinstance(comp2, _WebVTTCueTextSpan)
    assert comp2.text == "Laughs"

    # Example 02: the fixed header plus the rendered cue blocks must reproduce
    # the file content exactly.
    with open("./tests/data/webvtt/webvtt_example_02.vtt", encoding="utf-8") as f:
        content = f.read()
    vtt = _WebVTTFile.parse(content)
    assert len(vtt) == 4
    reverse = (
        "WEBVTT\n\nNOTE Copyright © 2019 World Wide Web Consortium. "
        "https://www.w3.org/TR/webvtt1/\n\n"
    )
    reverse += "\n\n".join([str(block) for block in vtt.cue_blocks])
    assert content == reverse

    # Example 03: every cue block carries an identifier.
    with open("./tests/data/webvtt/webvtt_example_03.vtt", encoding="utf-8") as f:
        content = f.read()
    vtt = _WebVTTFile.parse(content)
    assert len(vtt) == 13
    for block in vtt:
        assert block.identifier

    block = vtt.cue_blocks[0]
    assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0"
    assert str(block.timings) == "00:00:04.963 --> 00:00:08.571"
    assert len(block.payload) == 1
    assert isinstance(block.payload[0], _WebVTTCueVoiceSpan)

    # Only this block's own payload is inspected here: `cue_span` still refers
    # to the span taken from example 01 and says nothing about this cue.
    block = vtt.cue_blocks[2]
    assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0"
    assert str(block.timings) == "00:00:10.683 --> 00:00:11.563"
    assert len(block.payload) == 1
    assert isinstance(block.payload[0], _WebVTTCueTextSpan)
    assert block.payload[0].text == "Good."
def test_e2e_vtt_conversions():
    """Convert every sample ``.vtt`` file and verify its exports.

    Each file found under ``./tests/data/webvtt/`` is converted, then its
    markdown export, indented-text export and document are handed to the
    verify helpers together with paths under ``groundtruth/docling_v2``.
    """
    samples = sorted(Path("./tests/data/webvtt/").rglob("*.vtt"))
    vtt_converter = DocumentConverter(allowed_formats=[InputFormat.VTT])

    for sample in samples:
        # Common prefix of the reference files; only the suffix varies.
        gt_prefix = str(
            sample.parent.parent / "groundtruth" / "docling_v2" / sample.name
        )

        result: ConversionResult = vtt_converter.convert(sample)
        document: DoclingDocument = result.document

        markdown: str = document.export_to_markdown(escape_html=False)
        assert verify_export(markdown, gt_prefix + ".md", generate=GENERATE), (
            "export to md"
        )

        indented: str = document._export_to_indented_text(
            max_text_len=70, explicit_tables=False
        )
        assert verify_export(indented, gt_prefix + ".itxt", generate=GENERATE), (
            "export to indented-text"
        )

        assert verify_document(document, gt_prefix + ".json", GENERATE)