mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-11 22:28:31 +00:00
feat: add a backend parser for WebVTT files (#2288)
* feat: add a backend parser for WebVTT files Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com> * docs: update README with VTT support Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com> * docs: add description to supported formats Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com> * chore: upgrade docling-core to unescape WebVTT in markdown Pin the new release of docling-core 2.48.2. Do not escape HTML reserved characters when exporting WebVTT documents to markdown. Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com> * test: add missing copyright notice Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com> --------- Signed-off-by: Cesar Berrospi Ramis <ceb@zurich.ibm.com>
This commit is contained in:
committed by
GitHub
parent
b5628f1227
commit
46efaaefee
572
docling/backend/webvtt_backend.py
Normal file
572
docling/backend/webvtt_backend.py
Normal file
@@ -0,0 +1,572 @@
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Annotated, ClassVar, Literal, Optional, Union, cast
|
||||
|
||||
from docling_core.types.doc import (
|
||||
ContentLayer,
|
||||
DocItemLabel,
|
||||
DoclingDocument,
|
||||
DocumentOrigin,
|
||||
Formatting,
|
||||
GroupLabel,
|
||||
NodeItem,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.types import StringConstraints
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from docling.backend.abstract_backend import DeclarativeDocumentBackend
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.document import InputDocument
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _WebVTTTimestamp(BaseModel):
|
||||
"""Model representing a WebVTT timestamp.
|
||||
|
||||
A WebVTT timestamp is always interpreted relative to the current playback position
|
||||
of the media data that the WebVTT file is to be synchronized with.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(regex_engine="python-re")
|
||||
|
||||
raw: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="A representation of the WebVTT Timestamp as a single string"
|
||||
),
|
||||
]
|
||||
|
||||
_pattern: ClassVar[re.Pattern] = re.compile(
|
||||
r"^(?:(\d{2,}):)?([0-5]\d):([0-5]\d)\.(\d{3})$"
|
||||
)
|
||||
_hours: int
|
||||
_minutes: int
|
||||
_seconds: int
|
||||
_millis: int
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_raw(self) -> Self:
|
||||
m = self._pattern.match(self.raw)
|
||||
if not m:
|
||||
raise ValueError(f"Invalid WebVTT timestamp format: {self.raw}")
|
||||
self._hours = int(m.group(1)) if m.group(1) else 0
|
||||
self._minutes = int(m.group(2))
|
||||
self._seconds = int(m.group(3))
|
||||
self._millis = int(m.group(4))
|
||||
|
||||
if self._minutes < 0 or self._minutes > 59:
|
||||
raise ValueError("Minutes must be between 0 and 59")
|
||||
if self._seconds < 0 or self._seconds > 59:
|
||||
raise ValueError("Seconds must be between 0 and 59")
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def seconds(self) -> float:
|
||||
"""A representation of the WebVTT Timestamp in seconds"""
|
||||
return (
|
||||
self._hours * 3600
|
||||
+ self._minutes * 60
|
||||
+ self._seconds
|
||||
+ self._millis / 1000.0
|
||||
)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.raw
|
||||
|
||||
|
||||
_WebVTTCueIdentifier = Annotated[
|
||||
str, StringConstraints(strict=True, pattern=r"^(?!.*-->)[^\n\r]+$")
|
||||
]
|
||||
|
||||
|
||||
class _WebVTTCueTimings(BaseModel):
|
||||
"""Model representating WebVTT cue timings."""
|
||||
|
||||
start: Annotated[
|
||||
_WebVTTTimestamp, Field(description="Start time offset of the cue")
|
||||
]
|
||||
end: Annotated[_WebVTTTimestamp, Field(description="End time offset of the cue")]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_order(self) -> Self:
|
||||
if self.start and self.end:
|
||||
if self.end.seconds <= self.start.seconds:
|
||||
raise ValueError("End timestamp must be greater than start timestamp")
|
||||
return self
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
return f"{self.start} --> {self.end}"
|
||||
|
||||
|
||||
class _WebVTTCueTextSpan(BaseModel):
|
||||
"""Model representing a WebVTT cue text span."""
|
||||
|
||||
text: str
|
||||
span_type: Literal["text"] = "text"
|
||||
|
||||
@field_validator("text", mode="after")
|
||||
@classmethod
|
||||
def validate_text(cls, value: str) -> str:
|
||||
if any(ch in value for ch in {"\n", "\r", "&", "<"}):
|
||||
raise ValueError("Cue text span contains invalid characters")
|
||||
if len(value) == 0:
|
||||
raise ValueError("Cue text span cannot be empty")
|
||||
return value
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
|
||||
class _WebVTTCueVoiceSpan(BaseModel):
|
||||
"""Model representing a WebVTT cue voice span."""
|
||||
|
||||
annotation: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description=(
|
||||
"Cue span start tag annotation text representing the name of thevoice"
|
||||
)
|
||||
),
|
||||
]
|
||||
classes: Annotated[
|
||||
list[str],
|
||||
Field(description="List of classes representing the cue span's significance"),
|
||||
] = []
|
||||
components: Annotated[
|
||||
list["_WebVTTCueComponent"],
|
||||
Field(description="The components representing the cue internal text"),
|
||||
] = []
|
||||
span_type: Literal["v"] = "v"
|
||||
|
||||
@field_validator("annotation", mode="after")
|
||||
@classmethod
|
||||
def validate_annotation(cls, value: str) -> str:
|
||||
if any(ch in value for ch in {"\n", "\r", "&", ">"}):
|
||||
raise ValueError(
|
||||
"Cue span start tag annotation contains invalid characters"
|
||||
)
|
||||
if not value:
|
||||
raise ValueError("Cue text span cannot be empty")
|
||||
return value
|
||||
|
||||
@field_validator("classes", mode="after")
|
||||
@classmethod
|
||||
def validate_classes(cls, value: list[str]) -> list[str]:
|
||||
for item in value:
|
||||
if any(ch in item for ch in {"\t", "\n", "\r", " ", "&", "<", ">", "."}):
|
||||
raise ValueError(
|
||||
"A cue span start tag class contains invalid characters"
|
||||
)
|
||||
if not item:
|
||||
raise ValueError("Cue span start tag classes cannot be empty")
|
||||
return value
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
tag = f"v.{'.'.join(self.classes)}" if self.classes else "v"
|
||||
inner = "".join(str(span) for span in self.components)
|
||||
return f"<{tag} {self.annotation}>{inner}</v>"
|
||||
|
||||
|
||||
class _WebVTTCueClassSpan(BaseModel):
|
||||
span_type: Literal["c"] = "c"
|
||||
components: list["_WebVTTCueComponent"]
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
inner = "".join(str(span) for span in self.components)
|
||||
return f"<c>{inner}</c>"
|
||||
|
||||
|
||||
class _WebVTTCueItalicSpan(BaseModel):
|
||||
span_type: Literal["i"] = "i"
|
||||
components: list["_WebVTTCueComponent"]
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
inner = "".join(str(span) for span in self.components)
|
||||
return f"<i>{inner}</i>"
|
||||
|
||||
|
||||
class _WebVTTCueBoldSpan(BaseModel):
|
||||
span_type: Literal["b"] = "b"
|
||||
components: list["_WebVTTCueComponent"]
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
inner = "".join(str(span) for span in self.components)
|
||||
return f"<b>{inner}</b>"
|
||||
|
||||
|
||||
class _WebVTTCueUnderlineSpan(BaseModel):
|
||||
span_type: Literal["u"] = "u"
|
||||
components: list["_WebVTTCueComponent"]
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
inner = "".join(str(span) for span in self.components)
|
||||
return f"<u>{inner}</u>"
|
||||
|
||||
|
||||
_WebVTTCueComponent = Annotated[
|
||||
Union[
|
||||
_WebVTTCueTextSpan,
|
||||
_WebVTTCueClassSpan,
|
||||
_WebVTTCueItalicSpan,
|
||||
_WebVTTCueBoldSpan,
|
||||
_WebVTTCueUnderlineSpan,
|
||||
_WebVTTCueVoiceSpan,
|
||||
],
|
||||
Field(discriminator="span_type", description="The WebVTT cue component"),
|
||||
]
|
||||
|
||||
|
||||
class _WebVTTCueBlock(BaseModel):
|
||||
"""Model representing a WebVTT cue block.
|
||||
|
||||
The optional WebVTT cue settings list is not supported.
|
||||
The cue payload is limited to the following spans: text, class, italic, bold,
|
||||
underline, and voice.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(regex_engine="python-re")
|
||||
|
||||
identifier: Optional[_WebVTTCueIdentifier] = Field(
|
||||
None, description="The WebVTT cue identifier"
|
||||
)
|
||||
timings: Annotated[_WebVTTCueTimings, Field(description="The WebVTT cue timings")]
|
||||
payload: Annotated[list[_WebVTTCueComponent], Field(description="The cue payload")]
|
||||
|
||||
_pattern_block: ClassVar[re.Pattern] = re.compile(
|
||||
r"<(/?)(i|b|c|u|v(?:\.[^\t\n\r &<>.]+)*)(?:\s+([^>]*))?>"
|
||||
)
|
||||
_pattern_voice_tag: ClassVar[re.Pattern] = re.compile(
|
||||
r"^<v(?P<class>\.[^\t\n\r &<>]+)?" # zero or more classes
|
||||
r"[ \t]+(?P<annotation>[^\n\r&>]+)>" # required space and annotation
|
||||
)
|
||||
|
||||
@field_validator("payload", mode="after")
|
||||
@classmethod
|
||||
def validate_payload(cls, payload):
|
||||
for voice in payload:
|
||||
if "-->" in str(voice):
|
||||
raise ValueError("Cue payload must not contain '-->'")
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def parse(cls, raw: str) -> "_WebVTTCueBlock":
|
||||
lines = raw.strip().splitlines()
|
||||
if not lines:
|
||||
raise ValueError("Cue block must have at least one line")
|
||||
identifier: Optional[_WebVTTCueIdentifier] = None
|
||||
timing_line = lines[0]
|
||||
if "-->" not in timing_line and len(lines) > 1:
|
||||
identifier = timing_line
|
||||
timing_line = lines[1]
|
||||
cue_lines = lines[2:]
|
||||
else:
|
||||
cue_lines = lines[1:]
|
||||
|
||||
if "-->" not in timing_line:
|
||||
raise ValueError("Cue block must contain WebVTT cue timings")
|
||||
|
||||
start, end = [t.strip() for t in timing_line.split("-->")]
|
||||
end = re.split(" |\t", end)[0] # ignore the cue settings list
|
||||
timings: _WebVTTCueTimings = _WebVTTCueTimings(
|
||||
start=_WebVTTTimestamp(raw=start), end=_WebVTTTimestamp(raw=end)
|
||||
)
|
||||
cue_text = " ".join(cue_lines).strip()
|
||||
if cue_text.startswith("<v") and "</v>" not in cue_text:
|
||||
# adding close tag for cue voice spans without end tag
|
||||
cue_text += "</v>"
|
||||
|
||||
stack: list[list[_WebVTTCueComponent]] = [[]]
|
||||
tag_stack: list[Union[str, tuple]] = []
|
||||
|
||||
pos = 0
|
||||
matches = list(cls._pattern_block.finditer(cue_text))
|
||||
i = 0
|
||||
while i < len(matches):
|
||||
match = matches[i]
|
||||
if match.start() > pos:
|
||||
stack[-1].append(_WebVTTCueTextSpan(text=cue_text[pos : match.start()]))
|
||||
tag = match.group(0)
|
||||
|
||||
if tag.startswith(("<i>", "<b>", "<u>", "<c>")):
|
||||
tag_type = tag[1:2]
|
||||
tag_stack.append(tag_type)
|
||||
stack.append([])
|
||||
elif tag == "</i>":
|
||||
children = stack.pop()
|
||||
stack[-1].append(_WebVTTCueItalicSpan(components=children))
|
||||
tag_stack.pop()
|
||||
elif tag == "</b>":
|
||||
children = stack.pop()
|
||||
stack[-1].append(_WebVTTCueBoldSpan(components=children))
|
||||
tag_stack.pop()
|
||||
elif tag == "</u>":
|
||||
children = stack.pop()
|
||||
stack[-1].append(_WebVTTCueUnderlineSpan(components=children))
|
||||
tag_stack.pop()
|
||||
elif tag == "</c>":
|
||||
children = stack.pop()
|
||||
stack[-1].append(_WebVTTCueClassSpan(components=children))
|
||||
tag_stack.pop()
|
||||
elif tag.startswith("<v"):
|
||||
tag_stack.append(("v", tag))
|
||||
stack.append([])
|
||||
elif tag.startswith("</v"):
|
||||
children = stack.pop() if stack else []
|
||||
if (
|
||||
tag_stack
|
||||
and isinstance(tag_stack[-1], tuple)
|
||||
and tag_stack[-1][0] == "v"
|
||||
):
|
||||
_, voice = cast(tuple, tag_stack.pop())
|
||||
voice_match = cls._pattern_voice_tag.match(voice)
|
||||
if voice_match:
|
||||
class_string = voice_match.group("class")
|
||||
annotation = voice_match.group("annotation")
|
||||
if annotation:
|
||||
classes: list[str] = []
|
||||
if class_string:
|
||||
classes = [c for c in class_string.split(".") if c]
|
||||
stack[-1].append(
|
||||
_WebVTTCueVoiceSpan(
|
||||
annotation=annotation.strip(),
|
||||
classes=classes,
|
||||
components=children,
|
||||
)
|
||||
)
|
||||
|
||||
pos = match.end()
|
||||
i += 1
|
||||
|
||||
if pos < len(cue_text):
|
||||
stack[-1].append(_WebVTTCueTextSpan(text=cue_text[pos:]))
|
||||
|
||||
return cls(
|
||||
identifier=identifier,
|
||||
timings=timings,
|
||||
payload=stack[0],
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
parts = []
|
||||
if self.identifier:
|
||||
parts.append(f"{self.identifier}\n")
|
||||
timings_line = str(self.timings)
|
||||
parts.append(timings_line + "\n")
|
||||
for idx, span in enumerate(self.payload):
|
||||
if idx == 0 and len(self.payload) == 1 and span.span_type == "v":
|
||||
# the end tag may be omitted for brevity
|
||||
parts.append(str(span).removesuffix("</v>"))
|
||||
else:
|
||||
parts.append(str(span))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class _WebVTTFile(BaseModel):
|
||||
"""A model representing a WebVTT file."""
|
||||
|
||||
cue_blocks: list[_WebVTTCueBlock]
|
||||
|
||||
@staticmethod
|
||||
def verify_signature(content: str) -> bool:
|
||||
if not content:
|
||||
return False
|
||||
elif len(content) == 6:
|
||||
return content == "WEBVTT"
|
||||
elif len(content) > 6 and content.startswith("WEBVTT"):
|
||||
return content[6] in (" ", "\t", "\n")
|
||||
else:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse(cls, raw: str) -> "_WebVTTFile":
|
||||
# Normalize newlines to LF
|
||||
raw = raw.replace("\r\n", "\n").replace("\r", "\n")
|
||||
|
||||
# Check WebVTT signature
|
||||
if not cls.verify_signature(raw):
|
||||
raise ValueError("Invalid WebVTT file signature")
|
||||
|
||||
# Strip "WEBVTT" header line
|
||||
lines = raw.split("\n", 1)
|
||||
body = lines[1] if len(lines) > 1 else ""
|
||||
|
||||
# Remove NOTE/STYLE/REGION blocks
|
||||
body = re.sub(r"^(NOTE[^\n]*\n(?:.+\n)*?)\n", "", body, flags=re.MULTILINE)
|
||||
body = re.sub(r"^(STYLE|REGION)(?:.+\n)*?\n", "", body, flags=re.MULTILINE)
|
||||
|
||||
# Split into cue blocks
|
||||
raw_blocks = re.split(r"\n\s*\n", body.strip())
|
||||
cues: list[_WebVTTCueBlock] = []
|
||||
for block in raw_blocks:
|
||||
try:
|
||||
cues.append(_WebVTTCueBlock.parse(block))
|
||||
except ValueError as e:
|
||||
_log.warning(f"Failed to parse cue block:\n{block}\n{e}")
|
||||
|
||||
return cls(cue_blocks=cues)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cue_blocks)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cue_blocks[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.cue_blocks)
|
||||
|
||||
|
||||
class WebVTTDocumentBackend(DeclarativeDocumentBackend):
|
||||
"""Declarative backend for WebVTT (.vtt) files.
|
||||
|
||||
This parser reads the content of a WebVTT file and converts
|
||||
it to a DoclingDocument, following the W3C specs on https://www.w3.org/TR/webvtt1
|
||||
|
||||
Each cue becomes a TextItem and the items are appended to the
|
||||
document body by the cue's start time.
|
||||
"""
|
||||
|
||||
@override
|
||||
def __init__(self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]):
|
||||
super().__init__(in_doc, path_or_stream)
|
||||
|
||||
self.content: str = ""
|
||||
try:
|
||||
if isinstance(self.path_or_stream, BytesIO):
|
||||
self.content = self.path_or_stream.getvalue().decode("utf-8")
|
||||
if isinstance(self.path_or_stream, Path):
|
||||
with open(self.path_or_stream, encoding="utf-8") as f:
|
||||
self.content = f.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Could not initialize the WebVTT backend for file with hash "
|
||||
f"{self.document_hash}."
|
||||
) from e
|
||||
|
||||
@override
|
||||
def is_valid(self) -> bool:
|
||||
return _WebVTTFile.verify_signature(self.content)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def supports_pagination(cls) -> bool:
|
||||
return False
|
||||
|
||||
@override
|
||||
def unload(self):
|
||||
if isinstance(self.path_or_stream, BytesIO):
|
||||
self.path_or_stream.close()
|
||||
self.path_or_stream = None
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def supported_formats(cls) -> set[InputFormat]:
|
||||
return {InputFormat.VTT}
|
||||
|
||||
@staticmethod
|
||||
def _add_text_from_component(
|
||||
doc: DoclingDocument, item: _WebVTTCueComponent, parent: Optional[NodeItem]
|
||||
) -> None:
|
||||
"""Adds a TextItem to a document by extracting text from a cue span component.
|
||||
|
||||
TODO: address nesting
|
||||
"""
|
||||
formatting = Formatting()
|
||||
text = ""
|
||||
if isinstance(item, _WebVTTCueItalicSpan):
|
||||
formatting.italic = True
|
||||
elif isinstance(item, _WebVTTCueBoldSpan):
|
||||
formatting.bold = True
|
||||
elif isinstance(item, _WebVTTCueUnderlineSpan):
|
||||
formatting.underline = True
|
||||
if isinstance(item, _WebVTTCueTextSpan):
|
||||
text = item.text
|
||||
else:
|
||||
# TODO: address nesting
|
||||
text = "".join(
|
||||
[t.text for t in item.components if isinstance(t, _WebVTTCueTextSpan)]
|
||||
)
|
||||
if text := text.strip():
|
||||
doc.add_text(
|
||||
label=DocItemLabel.TEXT,
|
||||
text=text,
|
||||
parent=parent,
|
||||
content_layer=ContentLayer.BODY,
|
||||
formatting=formatting,
|
||||
)
|
||||
|
||||
@override
|
||||
def convert(self) -> DoclingDocument:
|
||||
_log.debug("Starting WebVTT conversion...")
|
||||
if not self.is_valid():
|
||||
raise RuntimeError("Invalid WebVTT document.")
|
||||
|
||||
origin = DocumentOrigin(
|
||||
filename=self.file.name or "file",
|
||||
mimetype="text/vtt",
|
||||
binary_hash=self.document_hash,
|
||||
)
|
||||
doc = DoclingDocument(name=self.file.stem or "file", origin=origin)
|
||||
|
||||
vtt: _WebVTTFile = _WebVTTFile.parse(self.content)
|
||||
for block in vtt.cue_blocks:
|
||||
block_group = doc.add_group(
|
||||
label=GroupLabel.SECTION,
|
||||
name="WebVTT cue block",
|
||||
parent=None,
|
||||
content_layer=ContentLayer.BODY,
|
||||
)
|
||||
if block.identifier:
|
||||
doc.add_text(
|
||||
label=DocItemLabel.TEXT,
|
||||
text=str(block.identifier),
|
||||
parent=block_group,
|
||||
content_layer=ContentLayer.BODY,
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.TEXT,
|
||||
text=str(block.timings),
|
||||
parent=block_group,
|
||||
content_layer=ContentLayer.BODY,
|
||||
)
|
||||
for cue_span in block.payload:
|
||||
if isinstance(cue_span, _WebVTTCueVoiceSpan):
|
||||
voice_group = doc.add_group(
|
||||
label=GroupLabel.INLINE,
|
||||
name="WebVTT cue voice span",
|
||||
parent=block_group,
|
||||
content_layer=ContentLayer.BODY,
|
||||
)
|
||||
voice = cue_span.annotation
|
||||
if classes := cue_span.classes:
|
||||
voice += f" ({', '.join(classes)})"
|
||||
voice += ": "
|
||||
doc.add_text(
|
||||
label=DocItemLabel.TEXT,
|
||||
text=voice,
|
||||
parent=voice_group,
|
||||
content_layer=ContentLayer.BODY,
|
||||
)
|
||||
for item in cue_span.components:
|
||||
WebVTTDocumentBackend._add_text_from_component(
|
||||
doc, item, voice_group
|
||||
)
|
||||
else:
|
||||
WebVTTDocumentBackend._add_text_from_component(
|
||||
doc, cue_span, block_group
|
||||
)
|
||||
|
||||
return doc
|
||||
@@ -1,7 +1,6 @@
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import (
|
||||
@@ -14,9 +13,7 @@ from docling_core.types.doc import (
|
||||
)
|
||||
from docling_core.types.doc.base import PydanticSerCtxKey, round_pydantic_float
|
||||
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
|
||||
from docling_core.types.io import (
|
||||
DocumentStream,
|
||||
)
|
||||
from docling_core.types.io import DocumentStream
|
||||
|
||||
# DO NOT REMOVE; explicitly exposed from this location
|
||||
from PIL.Image import Image
|
||||
@@ -71,6 +68,7 @@ class InputFormat(str, Enum):
|
||||
METS_GBS = "mets_gbs"
|
||||
JSON_DOCLING = "json_docling"
|
||||
AUDIO = "audio"
|
||||
VTT = "vtt"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
@@ -82,7 +80,7 @@ class OutputFormat(str, Enum):
|
||||
DOCTAGS = "doctags"
|
||||
|
||||
|
||||
FormatToExtensions: Dict[InputFormat, List[str]] = {
|
||||
FormatToExtensions: dict[InputFormat, list[str]] = {
|
||||
InputFormat.DOCX: ["docx", "dotx", "docm", "dotm"],
|
||||
InputFormat.PPTX: ["pptx", "potx", "ppsx", "pptm", "potm", "ppsm"],
|
||||
InputFormat.PDF: ["pdf"],
|
||||
@@ -97,9 +95,10 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
|
||||
InputFormat.METS_GBS: ["tar.gz"],
|
||||
InputFormat.JSON_DOCLING: ["json"],
|
||||
InputFormat.AUDIO: ["wav", "mp3"],
|
||||
InputFormat.VTT: ["vtt"],
|
||||
}
|
||||
|
||||
FormatToMimeType: Dict[InputFormat, List[str]] = {
|
||||
FormatToMimeType: dict[InputFormat, list[str]] = {
|
||||
InputFormat.DOCX: [
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.template",
|
||||
@@ -130,6 +129,7 @@ FormatToMimeType: Dict[InputFormat, List[str]] = {
|
||||
InputFormat.METS_GBS: ["application/mets+xml"],
|
||||
InputFormat.JSON_DOCLING: ["application/json"],
|
||||
InputFormat.AUDIO: ["audio/x-wav", "audio/mpeg", "audio/wav", "audio/mp3"],
|
||||
InputFormat.VTT: ["text/vtt"],
|
||||
}
|
||||
|
||||
MimeTypeToFormat: dict[str, list[InputFormat]] = {
|
||||
@@ -162,8 +162,8 @@ class Cluster(BaseModel):
|
||||
label: DocItemLabel
|
||||
bbox: BoundingBox
|
||||
confidence: float = 1.0
|
||||
cells: List[TextCell] = []
|
||||
children: List["Cluster"] = [] # Add child cluster support
|
||||
cells: list[TextCell] = []
|
||||
children: list["Cluster"] = [] # Add child cluster support
|
||||
|
||||
@field_serializer("confidence")
|
||||
def _serialize(self, value: float, info: FieldSerializationInfo) -> float:
|
||||
@@ -179,7 +179,7 @@ class BasePageElement(BaseModel):
|
||||
|
||||
|
||||
class LayoutPrediction(BaseModel):
|
||||
clusters: List[Cluster] = []
|
||||
clusters: list[Cluster] = []
|
||||
|
||||
|
||||
class VlmPredictionToken(BaseModel):
|
||||
@@ -201,14 +201,14 @@ class ContainerElement(
|
||||
|
||||
|
||||
class Table(BasePageElement):
|
||||
otsl_seq: List[str]
|
||||
otsl_seq: list[str]
|
||||
num_rows: int = 0
|
||||
num_cols: int = 0
|
||||
table_cells: List[TableCell]
|
||||
table_cells: list[TableCell]
|
||||
|
||||
|
||||
class TableStructurePrediction(BaseModel):
|
||||
table_map: Dict[int, Table] = {}
|
||||
table_map: dict[int, Table] = {}
|
||||
|
||||
|
||||
class TextElement(BasePageElement):
|
||||
@@ -216,7 +216,7 @@ class TextElement(BasePageElement):
|
||||
|
||||
|
||||
class FigureElement(BasePageElement):
|
||||
annotations: List[PictureDataType] = []
|
||||
annotations: list[PictureDataType] = []
|
||||
provenance: Optional[str] = None
|
||||
predicted_class: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
@@ -234,12 +234,12 @@ class FigureElement(BasePageElement):
|
||||
|
||||
class FigureClassificationPrediction(BaseModel):
|
||||
figure_count: int = 0
|
||||
figure_map: Dict[int, FigureElement] = {}
|
||||
figure_map: dict[int, FigureElement] = {}
|
||||
|
||||
|
||||
class EquationPrediction(BaseModel):
|
||||
equation_count: int = 0
|
||||
equation_map: Dict[int, TextElement] = {}
|
||||
equation_map: dict[int, TextElement] = {}
|
||||
|
||||
|
||||
class PagePredictions(BaseModel):
|
||||
@@ -254,9 +254,9 @@ PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
||||
|
||||
|
||||
class AssembledUnit(BaseModel):
|
||||
elements: List[PageElement] = []
|
||||
body: List[PageElement] = []
|
||||
headers: List[PageElement] = []
|
||||
elements: list[PageElement] = []
|
||||
body: list[PageElement] = []
|
||||
headers: list[PageElement] = []
|
||||
|
||||
|
||||
class ItemAndImageEnrichmentElement(BaseModel):
|
||||
@@ -280,12 +280,12 @@ class Page(BaseModel):
|
||||
None # Internal PDF backend. By default it is cleared during assembling.
|
||||
)
|
||||
_default_image_scale: float = 1.0 # Default image scale for external usage.
|
||||
_image_cache: Dict[
|
||||
_image_cache: dict[
|
||||
float, Image
|
||||
] = {} # Cache of images in different scales. By default it is cleared during assembling.
|
||||
|
||||
@property
|
||||
def cells(self) -> List[TextCell]:
|
||||
def cells(self) -> list[TextCell]:
|
||||
"""Return text cells as a read-only view of parsed_page.textline_cells."""
|
||||
if self.parsed_page is not None:
|
||||
return self.parsed_page.textline_cells
|
||||
@@ -354,7 +354,7 @@ class OpenAiApiResponse(BaseModel):
|
||||
|
||||
id: str
|
||||
model: Optional[str] = None # returned by openai
|
||||
choices: List[OpenAiResponseChoice]
|
||||
choices: list[OpenAiResponseChoice]
|
||||
created: int
|
||||
usage: OpenAiResponseUsage
|
||||
|
||||
@@ -430,7 +430,7 @@ class PageConfidenceScores(BaseModel):
|
||||
|
||||
|
||||
class ConfidenceReport(PageConfidenceScores):
|
||||
pages: Dict[int, PageConfidenceScores] = Field(
|
||||
pages: dict[int, PageConfidenceScores] = Field(
|
||||
default_factory=lambda: defaultdict(PageConfidenceScores)
|
||||
)
|
||||
|
||||
|
||||
@@ -394,6 +394,8 @@ class _DocumentConversionInput(BaseModel):
|
||||
mime = FormatToMimeType[InputFormat.PPTX][0]
|
||||
elif ext in FormatToExtensions[InputFormat.XLSX]:
|
||||
mime = FormatToMimeType[InputFormat.XLSX][0]
|
||||
elif ext in FormatToExtensions[InputFormat.VTT]:
|
||||
mime = FormatToMimeType[InputFormat.VTT][0]
|
||||
|
||||
return mime
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from docling.backend.msexcel_backend import MsExcelDocumentBackend
|
||||
from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
|
||||
from docling.backend.msword_backend import MsWordDocumentBackend
|
||||
from docling.backend.noop_backend import NoOpBackend
|
||||
from docling.backend.webvtt_backend import WebVTTDocumentBackend
|
||||
from docling.backend.xml.jats_backend import JatsDocumentBackend
|
||||
from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend
|
||||
from docling.datamodel.base_models import (
|
||||
@@ -170,6 +171,9 @@ def _get_default_option(format: InputFormat) -> FormatOption:
|
||||
pipeline_cls=SimplePipeline, backend=DoclingJSONBackend
|
||||
),
|
||||
InputFormat.AUDIO: FormatOption(pipeline_cls=AsrPipeline, backend=NoOpBackend),
|
||||
InputFormat.VTT: FormatOption(
|
||||
pipeline_cls=SimplePipeline, backend=WebVTTDocumentBackend
|
||||
),
|
||||
}
|
||||
if (options := format_to_default_options.get(format)) is not None:
|
||||
return options
|
||||
|
||||
Reference in New Issue
Block a user