mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
* feat: add a backend parser for WebVTT files
  Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
* docs: update README with VTT support
  Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
* docs: add description to supported formats
  Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
* chore: upgrade docling-core to unescape WebVTT in markdown
  Pin the new release of docling-core 2.48.2. Do not escape HTML reserved characters when exporting WebVTT documents to markdown.
  Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
* test: add missing copyright notice
  Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
---------
Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
233 lines
8.2 KiB
Python
233 lines
8.2 KiB
Python
# Assisted by watsonx Code Assistant
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from docling_core.types.doc import DoclingDocument
|
|
from pydantic import ValidationError
|
|
|
|
from docling.backend.webvtt_backend import (
|
|
_WebVTTCueItalicSpan,
|
|
_WebVTTCueTextSpan,
|
|
_WebVTTCueTimings,
|
|
_WebVTTCueVoiceSpan,
|
|
_WebVTTFile,
|
|
_WebVTTTimestamp,
|
|
)
|
|
from docling.datamodel.base_models import InputFormat
|
|
from docling.datamodel.document import ConversionResult
|
|
from docling.document_converter import DocumentConverter
|
|
|
|
from .test_data_gen_flag import GEN_TEST_DATA
|
|
from .verify_utils import verify_document, verify_export
|
|
|
|
# Module-wide alias of the shared test-data flag; it is passed as `generate`
# to verify_export / verify_document below. Presumably, when true, those
# helpers write ground-truth files instead of comparing against them —
# confirm in verify_utils.
GENERATE = GEN_TEST_DATA
|
|
|
|
|
|
# NOTE: "commponents" is misspelled; the name is kept so that existing
# `pytest -k` selections and references to this test keep matching.
def test_vtt_cue_commponents():
    """Test WebVTT cue components: timestamps, cue timings, text and voice spans."""
    _check_timestamps()
    _check_cue_timings()
    _check_text_spans()
    _check_voice_spans()


def _check_timestamps():
    """Check parsing, validation and string rendering of WebVTT timestamps."""
    # Valid timestamps (with and without the optional hours field), each
    # paired with the total number of seconds it represents.
    valid_cases = [
        ("00:01:02.345", 1 * 60 + 2.345),
        ("12:34:56.789", 12 * 3600 + 34 * 60 + 56.789),
        ("02:34.567", 2 * 60 + 34.567),
        ("00:00:00.000", 0.0),
    ]
    for raw, expected_seconds in valid_cases:
        model = _WebVTTTimestamp(raw=raw)
        # Compare with a tight absolute tolerance: the model may add the
        # hour/minute/second/millisecond parts in a different order than the
        # literal above, which can change the last bit of the float.
        assert model.seconds == pytest.approx(expected_seconds, abs=1e-6), raw

    invalid_timestamps = [
        "00:60:02.345",  # minutes > 59
        "00:01:60.345",  # seconds > 59
        "00:01:02.1000",  # milliseconds > 999
        "01:02:03",  # missing milliseconds
        "01:02",  # missing milliseconds
        ":01:02.345",  # extra : for missing hours
        "abc:01:02.345",  # invalid format
    ]
    for raw in invalid_timestamps:
        with pytest.raises(ValidationError):
            _WebVTTTimestamp(raw=raw)

    # __str__ reproduces the raw timestamp.
    model = _WebVTTTimestamp(raw="00:01:02.345")
    assert str(model) == "00:01:02.345"


def _check_cue_timings():
    """Check construction, rendering and validation of WebVTT cue timings."""
    start = _WebVTTTimestamp(raw="00:10.005")
    end = _WebVTTTimestamp(raw="00:14.007")
    cue_timings = _WebVTTCueTimings(start=start, end=end)
    assert cue_timings.start == start
    assert cue_timings.end == end
    assert str(cue_timings) == "00:10.005 --> 00:14.007"

    # An end timestamp before the start timestamp is rejected.
    with pytest.raises(ValidationError) as excinfo:
        _WebVTTCueTimings(
            start=_WebVTTTimestamp(raw="00:10.700"),
            end=_WebVTTTimestamp(raw="00:10.500"),
        )
    assert "End timestamp must be greater than start timestamp" in str(excinfo.value)

    # Both `start` and `end` are required: omitting either one fails.
    timestamp = _WebVTTTimestamp(raw="00:10.500")
    for partial_kwargs in ({"start": timestamp}, {"end": timestamp}):
        with pytest.raises(ValidationError) as excinfo:
            _WebVTTCueTimings(**partial_kwargs)
        assert "Field required" in str(excinfo.value)


def _check_text_spans():
    """Check validation and rendering of plain cue text spans."""
    valid_text = "This is a valid cue text span."
    span = _WebVTTCueTextSpan(text=valid_text)
    assert span.text == valid_text
    assert str(span) == valid_text

    # Text containing a newline, a raw "&" or "<", or empty text is rejected.
    invalid_texts = [
        "This cue text span\ncontains a newline.",
        "This cue text span contains &.",
        "This cue text span contains <.",
        "",
    ]
    for text in invalid_texts:
        with pytest.raises(ValidationError):
            _WebVTTCueTextSpan(text=text)


def _check_voice_spans():
    """Check validation and rendering of cue voice spans."""
    # Annotation validation: a newline inside the annotation is rejected.
    with pytest.raises(ValidationError):
        _WebVTTCueVoiceSpan(annotation="invalid\nannotation")
    span = _WebVTTCueVoiceSpan(annotation="valid-annotation")
    assert span.annotation == "valid-annotation"

    annotation = "speaker name"

    # Classes validation: a list holding a newline-bearing class and an empty
    # class is rejected; plain class names are accepted.
    with pytest.raises(ValidationError):
        _WebVTTCueVoiceSpan(
            annotation=annotation, classes=["class\nwith\nnewlines", ""]
        )
    valid_classes = ["class1", "class2"]
    span = _WebVTTCueVoiceSpan(annotation=annotation, classes=valid_classes)
    assert list(span.classes) == valid_classes

    # Components validation: arbitrary values are rejected; cue spans are accepted.
    with pytest.raises(ValidationError):
        _WebVTTCueVoiceSpan(annotation=annotation, components=[123, "not a component"])
    span = _WebVTTCueVoiceSpan(
        annotation=annotation, components=[_WebVTTCueTextSpan(text="random text")]
    )
    assert len(span.components) == 1

    # Rendering with classes: they are appended to the "v" tag, dot-separated.
    cue_span = _WebVTTCueVoiceSpan(
        annotation="speaker",
        classes=["loud", "clear"],
        components=[_WebVTTCueTextSpan(text="random text")],
    )
    assert str(cue_span) == "<v.loud.clear speaker>random text</v>"

    # Rendering without classes.
    cue_span = _WebVTTCueVoiceSpan(
        annotation="speaker",
        components=[_WebVTTCueTextSpan(text="random text")],
    )
    assert str(cue_span) == "<v speaker>random text</v>"
|
|
|
|
|
|
def test_webvtt_file():
    """Test parsing WebVTT files into cue blocks, timings and nested spans."""
    _check_webvtt_example_01()
    _check_webvtt_example_02()
    _check_webvtt_example_03()


def _read_webvtt_example(name):
    """Return the UTF-8 text of the WebVTT test file ``tests/data/webvtt/<name>``."""
    return Path("./tests/data/webvtt", name).read_text(encoding="utf-8")


def _check_webvtt_example_01():
    """Check the nested voice -> italic -> text spans of a cue in example 01."""
    vtt = _WebVTTFile.parse(_read_webvtt_example("webvtt_example_01.vtt"))
    assert len(vtt) == 13

    block = vtt.cue_blocks[11]
    assert str(block.timings) == "00:32.500 --> 00:33.500"
    assert len(block.payload) == 1

    # Outer layer: a voice span with an annotation and no classes.
    cue_span = block.payload[0]
    assert isinstance(cue_span, _WebVTTCueVoiceSpan)
    assert cue_span.annotation == "Neil deGrasse Tyson"
    assert not cue_span.classes
    assert len(cue_span.components) == 1

    # Middle layer: an italic span wrapping a single component.
    italic = cue_span.components[0]
    assert isinstance(italic, _WebVTTCueItalicSpan)
    assert len(italic.components) == 1

    # Inner layer: the plain text.
    text = italic.components[0]
    assert isinstance(text, _WebVTTCueTextSpan)
    assert text.text == "Laughs"


def _check_webvtt_example_02():
    """Check that re-serializing example 02's cue blocks reproduces the file."""
    content = _read_webvtt_example("webvtt_example_02.vtt")
    vtt = _WebVTTFile.parse(content)
    assert len(vtt) == 4

    # The header and NOTE block are prepended by hand; only the cue blocks
    # are re-serialized from the parsed model.
    header = (
        "WEBVTT\n\nNOTE Copyright © 2019 World Wide Web Consortium. "
        "https://www.w3.org/TR/webvtt1/\n\n"
    )
    reverse = header + "\n\n".join(str(block) for block in vtt.cue_blocks)
    assert content == reverse


def _check_webvtt_example_03():
    """Check cue identifiers, timings and payload types parsed from example 03."""
    vtt = _WebVTTFile.parse(_read_webvtt_example("webvtt_example_03.vtt"))
    assert len(vtt) == 13

    # Every cue block in this file carries an identifier.
    for idx, block in enumerate(vtt):
        assert block.identifier, f"cue block {idx} has no identifier"

    block = vtt.cue_blocks[0]
    assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0"
    assert str(block.timings) == "00:00:04.963 --> 00:00:08.571"
    assert len(block.payload) == 1
    assert isinstance(block.payload[0], _WebVTTCueVoiceSpan)

    # Unlike cue 0, this cue's single payload item is a plain text span
    # rather than a voice span.
    block = vtt.cue_blocks[2]
    assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0"
    assert str(block.timings) == "00:00:10.683 --> 00:00:11.563"
    assert len(block.payload) == 1
    assert isinstance(block.payload[0], _WebVTTCueTextSpan)
    assert block.payload[0].text == "Good."
|
|
|
|
|
|
def test_e2e_vtt_conversions():
    """Convert every WebVTT test file and compare its exports with ground truth.

    Each ``*.vtt`` file found under ``tests/data/webvtt`` is converted with a
    VTT-only DocumentConverter. The resulting document's markdown,
    indented-text and JSON exports are checked against files named
    ``<vtt file name>.md`` / ``.itxt`` / ``.json`` under
    ``groundtruth/docling_v2``, two directory levels above the input file.
    """
    directory = Path("./tests/data/webvtt/")
    vtt_paths = sorted(directory.rglob("*.vtt"))
    # Fail loudly instead of passing vacuously when the data directory is
    # missing or empty (e.g. pytest launched from another working directory).
    assert vtt_paths, f"no .vtt test files found under {directory}"

    converter = DocumentConverter(allowed_formats=[InputFormat.VTT])

    for vtt in vtt_paths:
        gt_path = vtt.parent.parent / "groundtruth" / "docling_v2" / vtt.name

        conv_result: ConversionResult = converter.convert(vtt)
        doc: DoclingDocument = conv_result.document

        # WebVTT documents are exported to markdown without escaping HTML
        # reserved characters.
        pred_md: str = doc.export_to_markdown(escape_html=False)
        assert verify_export(pred_md, str(gt_path) + ".md", generate=GENERATE), (
            f"export to md ({vtt.name})"
        )

        pred_itxt: str = doc._export_to_indented_text(
            max_text_len=70, explicit_tables=False
        )
        assert verify_export(pred_itxt, str(gt_path) + ".itxt", generate=GENERATE), (
            f"export to indented-text ({vtt.name})"
        )

        assert verify_document(doc, str(gt_path) + ".json", GENERATE), (
            f"export to json ({vtt.name})"
        )
|