diff --git a/docling/backend/mets_gbs_backend.py b/docling/backend/mets_gbs_backend.py
new file mode 100644
index 00000000..f2a7d2b5
--- /dev/null
+++ b/docling/backend/mets_gbs_backend.py
@@ -0,0 +1,387 @@
+"""Backend for GBS Google Books schema."""
+
+import logging
+import tarfile
+from collections.abc import Iterable
+from dataclasses import dataclass
+from enum import Enum
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+
+from docling_core.types.doc import BoundingBox, CoordOrigin, Size
+from docling_core.types.doc.page import (
+    BoundingRectangle,
+    PdfPageBoundaryType,
+    PdfPageGeometry,
+    SegmentedPdfPage,
+    TextCell,
+)
+from lxml import etree
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from docling.backend.abstract_backend import PaginatedDocumentBackend
+from docling.backend.pdf_backend import PdfPageBackend
+from docling.datamodel.base_models import InputFormat
+
+if TYPE_CHECKING:
+    from docling.datamodel.document import InputDocument
+
+_log = logging.getLogger(__name__)
+
+
+def _get_pdf_page_geometry(
+    size: Size,
+) -> PdfPageGeometry:
+    """Build a PdfPageGeometry where every box equals the full page rectangle."""
+    boundary_type: PdfPageBoundaryType = PdfPageBoundaryType.CROP_BOX
+
+    bbox_tuple = (0, 0, size.width, size.height)
+    bbox = BoundingBox.from_tuple(bbox_tuple, CoordOrigin.TOPLEFT)
+
+    return PdfPageGeometry(
+        angle=0.0,
+        rect=BoundingRectangle.from_bounding_box(bbox),
+        boundary_type=boundary_type,
+        art_bbox=bbox,
+        bleed_bbox=bbox,
+        crop_bbox=bbox,
+        media_bbox=bbox,
+        trim_bbox=bbox,
+    )
+
+
+class MetsGbsPageBackend(PdfPageBackend):
+    def __init__(self, parsed_page: SegmentedPdfPage, page_im: PILImage):
+        self._im = page_im
+        self._dpage = parsed_page
+        self.valid = parsed_page is not None
+
+    def is_valid(self) -> bool:
+        return self.valid
+
+    def get_text_in_rect(self, bbox: BoundingBox) -> str:
+        # Find intersecting cells on the page
+        text_piece = ""
+        page_size = self.get_size()
+
+        scale = (
+            1  # FIX - Replace with param in get_text_in_rect across backends (optional)
+        )
+
+        for i, cell in enumerate(self._dpage.textline_cells):
+            cell_bbox = (
+                cell.rect.to_bounding_box()
+                .to_top_left_origin(page_height=page_size.height)
+                .scaled(scale)
+            )
+
+            overlap_frac = cell_bbox.intersection_over_self(bbox)
+
+            if overlap_frac > 0.5:
+                if len(text_piece) > 0:
+                    text_piece += " "
+                text_piece += cell.text
+
+        return text_piece
+
+    def get_segmented_page(self) -> Optional[SegmentedPdfPage]:
+        return self._dpage
+
+    def get_text_cells(self) -> Iterable[TextCell]:
+        return self._dpage.textline_cells
+
+    def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
+        AREA_THRESHOLD = 0  # 32 * 32
+
+        images = self._dpage.bitmap_resources
+
+        for img in images:
+            cropbox = img.rect.to_bounding_box().to_top_left_origin(
+                self.get_size().height
+            )
+
+            if cropbox.area() > AREA_THRESHOLD:
+                cropbox = cropbox.scaled(scale=scale)
+
+                yield cropbox
+
+    def get_page_image(
+        self, scale: float = 1, cropbox: Optional[BoundingBox] = None
+    ) -> Image.Image:
+        page_size = self.get_size()
+        assert (
+            page_size.width == self._im.size[0] and page_size.height == self._im.size[1]
+        )
+
+        if not cropbox:
+            cropbox = BoundingBox(
+                l=0,
+                r=page_size.width,
+                t=0,
+                b=page_size.height,
+                coord_origin=CoordOrigin.TOPLEFT,
+            )
+
+        image = self._im.resize(
+            size=(round(page_size.width * scale), round(page_size.height * scale))
+        ).crop(cropbox.scaled(scale=scale).as_tuple())
+        return image
+
+    def get_size(self) -> Size:
+        return Size(
+            width=self._dpage.dimension.width, height=self._dpage.dimension.height
+        )
+
+    def unload(self):
+        self._im = None  # FIX: was self._ppage, an attribute this class never defines
+        self._dpage = None
+
+
+class _UseType(str, Enum):
+    IMAGE = "image"
+    OCR = "OCR"
+    COORD_OCR = "coordOCR"
+
+
+@dataclass
+class _FileInfo:
+    file_id: str
+    mimetype: str
+    path: str
+    use: _UseType
+
+
+@dataclass
+class _PageFiles:
+    image: Optional[_FileInfo] = None
+    ocr: Optional[_FileInfo] = None
+    coordOCR: Optional[_FileInfo] = None
+
+
+def _extract_rect(title_str: str) -> Optional[BoundingRectangle]:
+    """
+    Extracts bbox from title string like 'bbox 279 177 306 214;x_wconf 97'
+    """
+    parts = title_str.split(";")
+    for part in parts:
+        part = part.strip()
+        if part.startswith("bbox "):
+            try:
+                coords = part.split()[1:]
+                rect = BoundingRectangle.from_bounding_box(
+                    bbox=BoundingBox.from_tuple(
+                        tuple(map(int, coords)), origin=CoordOrigin.TOPLEFT
+                    )
+                )
+                return rect
+            except Exception:
+                return None
+    return None
+
+
+def _extract_confidence(title_str) -> float:
+    """Extracts x_wconf (OCR confidence) value from title string."""
+    for part in title_str.split(";"):
+        part = part.strip()
+        if part.startswith("x_wconf"):
+            try:
+                return float(part.split()[1]) / 100.0
+            except Exception:
+                return 1.0
+    return 1.0
+
+
+class MetsGbsDocumentBackend(PaginatedDocumentBackend):
+    def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
+        super().__init__(in_doc, path_or_stream)
+
+        self._tar: tarfile.TarFile = (
+            tarfile.open(name=self.path_or_stream, mode="r:gz")
+            if isinstance(self.path_or_stream, Path)
+            else tarfile.open(fileobj=self.path_or_stream, mode="r:gz")
+        )
+        self.root_mets: Optional[etree._Element] = None
+        self.page_map: Dict[int, _PageFiles] = {}
+
+        for member in self._tar.getmembers():
+            if member.name.endswith(".xml"):
+                file = self._tar.extractfile(member)
+                if file is not None:
+                    content = file.read()
+                    self.root_mets = self._validate_mets_xml(content)
+                    if self.root_mets is not None:
+                        break
+
+        if self.root_mets is None:
+            raise RuntimeError(
+                f"METS GBS backend could not load document {self.document_hash}."
+            )
+
+        ns = {
+            "mets": "http://www.loc.gov/METS/",
+            "xlink": "http://www.w3.org/1999/xlink",
+            "xsi": "http://www.w3.org/2001/XMLSchema-instance",
+            "gbs": "http://books.google.com/gbs",
+            "premis": "info:lc/xmlns/premis-v2",
+            "marc": "http://www.loc.gov/MARC21/slim",
+        }
+
+        file_info_by_id: Dict[str, _FileInfo] = {}
+
+        for filegrp in self.root_mets.xpath(".//mets:fileGrp", namespaces=ns):
+            use_raw = filegrp.get("USE")
+            try:
+                use = _UseType(use_raw)
+            except ValueError:
+                continue  # Ignore unknown USE types
+
+            for file_elem in filegrp.xpath("./mets:file", namespaces=ns):
+                file_id = file_elem.get("ID")
+                mimetype = file_elem.get("MIMETYPE")
+                flocat_elem = file_elem.find("mets:FLocat", namespaces=ns)
+                href = (
+                    flocat_elem.get("{http://www.w3.org/1999/xlink}href")
+                    if flocat_elem is not None
+                    else None
+                )
+                if href is None:
+                    continue
+
+                file_info_by_id[file_id] = _FileInfo(
+                    file_id=file_id, mimetype=mimetype, path=href, use=use
+                )
+
+        USE_TO_ATTR = {
+            _UseType.IMAGE: "image",
+            _UseType.OCR: "ocr",
+            _UseType.COORD_OCR: "coordOCR",
+        }
+
+        for div in self.root_mets.xpath('.//mets:div[@TYPE="page"]', namespaces=ns):
+            order_str = div.get("ORDER")
+            if not order_str:
+                continue
+            try:
+                page_no = int(order_str) - 1  # make 0-index pages
+            except ValueError:
+                continue
+
+            page_files = _PageFiles()
+
+            for fptr in div.xpath("./mets:fptr", namespaces=ns):
+                file_id = fptr.get("FILEID")
+                file_info = file_info_by_id.get(file_id)
+
+                if file_info:
+                    attr = USE_TO_ATTR.get(file_info.use)
+                    if attr:
+                        setattr(page_files, attr, file_info)
+
+            self.page_map[page_no] = page_files
+
+    def _validate_mets_xml(self, xml_string) -> Optional[etree._Element]:
+        root: etree._Element = etree.fromstring(xml_string)
+        if (
+            root.tag == "{http://www.loc.gov/METS/}mets"
+            and root.get("PROFILE") == "gbs"
+        ):
+            return root
+
+        _log.warning(f"The root element is not with PROFILE='gbs': {root}")
+        return None
+
+    def _parse_page(self, page_no: int) -> Tuple[SegmentedPdfPage, PILImage]:
+        # TODO: use better fallbacks...
+        image_info = self.page_map[page_no].image
+        assert image_info is not None
+        ocr_info = self.page_map[page_no].coordOCR
+        assert ocr_info is not None
+
+        image_file = self._tar.extractfile(image_info.path)
+        assert image_file is not None
+        buf = BytesIO(image_file.read())
+        im: PILImage = Image.open(buf)
+        ocr_file = self._tar.extractfile(ocr_info.path)
+        assert ocr_file is not None
+        ocr_content = ocr_file.read()
+        ocr_root: etree._Element = etree.fromstring(ocr_content)
+
+        line_cells: List[TextCell] = []
+        word_cells: List[TextCell] = []
+
+        ns = {"x": "http://www.w3.org/1999/xhtml"}
+        page_div = ocr_root.xpath("//x:div[@class='ocr_page']", namespaces=ns)
+
+        size = Size(width=im.size[0], height=im.size[1])
+        if page_div:
+            title = page_div[0].attrib.get("title", "")
+            rect = _extract_rect(title)
+            if rect:
+                size = Size(width=rect.width, height=rect.height)
+        else:
+            _log.error(f"Could not find ocr_page for page {page_no}")
+
+        im = im.resize(size=(round(size.width), round(size.height)))
+        im = im.convert("RGB")
+
+        # Extract all ocrx_word spans
+        for word in ocr_root.xpath("//x:span[@class='ocrx_word']", namespaces=ns):
+            text = "".join(word.itertext()).strip()
+            title = word.attrib.get("title", "")
+            rect = _extract_rect(title)
+            conf = _extract_confidence(title)
+            if rect:
+                word_cells.append(
+                    TextCell(
+                        text=text, orig=text, rect=rect, from_ocr=True, confidence=conf
+                    )
+                )
+
+        # Extract all ocr_line spans
+        # line: etree._Element
+        for line in ocr_root.xpath("//x:span[@class='ocr_line']", namespaces=ns):
+            text = "".join(line.itertext()).strip()
+            title = line.attrib.get("title", "")
+            rect = _extract_rect(title)
+            conf = _extract_confidence(title)
+            if rect:
+                line_cells.append(
+                    TextCell(
+                        text=text, orig=text, rect=rect, from_ocr=True, confidence=conf
+                    )
+                )
+
+        page = SegmentedPdfPage(
+            dimension=_get_pdf_page_geometry(size),
+            textline_cells=line_cells,
+            char_cells=[],
+            word_cells=word_cells,
+            has_textlines=True,
+            has_words=True,
+            has_chars=False,
+        )
+        return page, im
+
+    def page_count(self) -> int:
+        return len(self.page_map)
+
+    def load_page(self, page_no: int) -> MetsGbsPageBackend:
+        # TODO: is this thread-safe?
+        page, im = self._parse_page(page_no)
+        return MetsGbsPageBackend(parsed_page=page, page_im=im)
+
+    def is_valid(self) -> bool:
+        return self.root_mets is not None and self.page_count() > 0
+
+    @classmethod
+    def supported_formats(cls) -> Set[InputFormat]:
+        return {InputFormat.XML_METS_GBS}
+
+    @classmethod
+    def supports_pagination(cls) -> bool:
+        return True
+
+    def unload(self):
+        super().unload()
+        self._tar.close()
diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index c753ac60..6825e125 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -56,6 +56,7 @@ class InputFormat(str, Enum):
     XLSX = "xlsx"
     XML_USPTO = "xml_uspto"
     XML_JATS = "xml_jats"
+    XML_METS_GBS = "xml_mets_gbs"
     JSON_DOCLING = "json_docling"
     AUDIO = "audio"
 
@@ -81,6 +82,7 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
     InputFormat.CSV: ["csv"],
     InputFormat.XLSX: ["xlsx", "xlsm"],
     InputFormat.XML_USPTO: ["xml", "txt"],
+    InputFormat.XML_METS_GBS: ["tar.gz"],
     InputFormat.JSON_DOCLING: ["json"],
     InputFormat.AUDIO: ["wav", "mp3"],
 }
@@ -113,6 +115,7 @@ FormatToMimeType: Dict[InputFormat, List[str]] = {
         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
     ],
     InputFormat.XML_USPTO: ["application/xml", "text/plain"],
+    InputFormat.XML_METS_GBS: ["application/mets+xml"],
     InputFormat.JSON_DOCLING: ["application/json"],
     InputFormat.AUDIO: ["audio/x-wav", "audio/mpeg", "audio/wav", "audio/mp3"],
 }
diff --git a/docling/datamodel/document.py b/docling/datamodel/document.py
index 9f5cf82c..a9a3c9b1 100644
--- a/docling/datamodel/document.py
+++ b/docling/datamodel/document.py
@@ -1,6 +1,7 @@
 import csv
 import logging
 import re
+import tarfile
 from collections.abc import Iterable
 from enum import Enum
 from io import BytesIO
@@ -314,21 +315,25 @@ class _DocumentConversionInput(BaseModel):
         elif objname.endswith(".pptx"):
             mime = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
 
+        if mime is not None and mime.lower() == "application/gzip":
+            if detected_mime := _DocumentConversionInput._detect_mets_gbs(obj):
+                mime = detected_mime
+
         mime = mime or _DocumentConversionInput._detect_html_xhtml(content)
         mime = mime or _DocumentConversionInput._detect_csv(content)
         mime = mime or "text/plain"
         formats = MimeTypeToFormat.get(mime, [])
         _log.info(f"detected formats: {formats}")
 
-        if formats:
-            if len(formats) == 1 and mime not in ("text/plain"):
-                return formats[0]
-            else:  # ambiguity in formats
-                return _DocumentConversionInput._guess_from_content(
-                    content, mime, formats
-                )
-        else:
-            return None
+        input_format: Optional[InputFormat] = None
+        if len(formats) == 1:
+            input_format = formats[0]
+
+        if content:
+            input_format = _DocumentConversionInput._guess_from_content(
+                content, mime, formats
+            )
+        return input_format
 
 @@ -337,6 +342,9 @@ class _DocumentConversionInput(BaseModel):
         """Guess the input format of a document by checking part of its content."""
         input_format: Optional[InputFormat] = None
 
+        if len(formats) == 1:
+            input_format = formats[0]
+
         if mime == "application/xml":
             content_str = content.decode("utf-8")
             match_doctype = re.search(r"<!DOCTYPE [^>]+>", content_str)
@@ -457,3 +465,24 @@ class _DocumentConversionInput(BaseModel):
             return None
 
         return None
+
+    @staticmethod
+    def _detect_mets_gbs(
+        obj: Union[Path, DocumentStream],
+    ) -> Optional[Literal["application/mets+xml"]]:
+        content = obj if isinstance(obj, Path) else obj.stream
+        tar: tarfile.TarFile
+        member: tarfile.TarInfo
+        with tarfile.open(
+            name=content if isinstance(content, Path) else None,
+            fileobj=content if isinstance(content, BytesIO) else None,
+            mode="r:gz",
+        ) as tar:
+            for member in tar.getmembers():
+                if member.name.endswith(".xml"):
+                    file = tar.extractfile(member)
+                    if file is not None:
+                        content_str = file.read().decode()
+                        if "http://www.loc.gov/METS/" in content_str:
+                            return "application/mets+xml"
+        return None
diff --git a/docling/document_converter.py b/docling/document_converter.py
index f3bcb89e..fea14f38 100644
--- a/docling/document_converter.py
+++ b/docling/document_converter.py
@@ -17,6 +17,7 @@ from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBacke
 from docling.backend.html_backend import HTMLDocumentBackend
 from docling.backend.json.docling_json_backend import DoclingJSONBackend
 from docling.backend.md_backend import MarkdownDocumentBackend
+from docling.backend.mets_gbs_backend import MetsGbsDocumentBackend
 from docling.backend.msexcel_backend import MsExcelDocumentBackend
 from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
 from docling.backend.msword_backend import MsWordDocumentBackend
@@ -156,6 +157,9 @@ def _get_default_option(format: InputFormat) -> FormatOption:
         InputFormat.XML_JATS: FormatOption(
             pipeline_cls=SimplePipeline, backend=JatsDocumentBackend
         ),
+        InputFormat.XML_METS_GBS: FormatOption(
+            pipeline_cls=StandardPdfPipeline, backend=MetsGbsDocumentBackend
+        ),
         InputFormat.IMAGE: FormatOption(
             pipeline_cls=StandardPdfPipeline, backend=DoclingParseV4DocumentBackend
         ),
diff --git a/docling/pipeline/base_pipeline.py b/docling/pipeline/base_pipeline.py
index 6944a355..5c289ad9 100644
--- a/docling/pipeline/base_pipeline.py
+++ b/docling/pipeline/base_pipeline.py
@@ -8,7 +8,10 @@ from typing import Any, Callable, List
 
 from docling_core.types.doc import NodeItem
 
-from docling.backend.abstract_backend import AbstractDocumentBackend
+from docling.backend.abstract_backend import (
+    AbstractDocumentBackend,
+    PaginatedDocumentBackend,
+)
 from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import (
     ConversionStatus,
@@ -126,10 +129,10 @@ class PaginatedPipeline(BasePipeline):  # TODO this is a bad name.
             yield from page_batch
 
     def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
-        if not isinstance(conv_res.input._backend, PdfDocumentBackend):
+        if not isinstance(conv_res.input._backend, PaginatedDocumentBackend):
             raise RuntimeError(
-                f"The selected backend {type(conv_res.input._backend).__name__} for {conv_res.input.file} is not a PDF backend. "
-                f"Can not convert this with a PDF pipeline. "
+                f"The selected backend {type(conv_res.input._backend).__name__} for {conv_res.input.file} is not a paginated backend. "
+                f"Can not convert this with a paginated PDF pipeline. "
                 f"Please check your format configuration on DocumentConverter."
             )
         # conv_res.status = ConversionStatus.FAILURE
diff --git a/tests/test_backend_mets_gbs.py b/tests/test_backend_mets_gbs.py
new file mode 100644
index 00000000..c8be4327
--- /dev/null
+++ b/tests/test_backend_mets_gbs.py
@@ -0,0 +1,89 @@
+from pathlib import Path
+
+import pytest
+
+from docling.backend.mets_gbs_backend import MetsGbsDocumentBackend, MetsGbsPageBackend
+from docling.datamodel.base_models import BoundingBox, InputFormat
+from docling.datamodel.document import InputDocument
+
+
+@pytest.fixture
+def test_doc_path():
+    return Path("./tests/data/mets_gbs/32044009881525.tar.gz")  # FIX: was a developer-local /Users/... path
+
+
+def _get_backend(pdf_doc):
+    in_doc = InputDocument(
+        path_or_stream=pdf_doc,
+        format=InputFormat.XML_METS_GBS,  # FIX: was InputFormat.PDF, unsupported by this backend
+        backend=MetsGbsDocumentBackend,
+    )
+
+    doc_backend = in_doc._backend
+    return doc_backend
+
+
+def test_process_pages(test_doc_path):
+    doc_backend: MetsGbsDocumentBackend = _get_backend(test_doc_path)
+
+    for page_index in range(doc_backend.page_count()):
+        page_backend: MetsGbsPageBackend = doc_backend.load_page(page_index)
+        list(page_backend.get_text_cells())
+
+        # Clean up page backend after each iteration
+        page_backend.unload()
+
+    # Explicitly clean up document backend to prevent race conditions in CI
+    doc_backend.unload()
+
+
+def test_get_text_from_rect(test_doc_path):
+    doc_backend: MetsGbsDocumentBackend = _get_backend(test_doc_path)
+    page_backend: MetsGbsPageBackend = doc_backend.load_page(9)
+
+    # Get the title text of the DocLayNet paper
+    textpiece = page_backend.get_text_in_rect(
+        bbox=BoundingBox(l=275, t=263, r=1388, b=311)
+    )
+    ref = "recently become prevalent that he who speaks"
+
+    assert textpiece.strip() == ref
+
+    # Explicitly clean up resources
+    page_backend.unload()
+    doc_backend.unload()
+
+
+def test_crop_page_image(test_doc_path):
+    doc_backend: MetsGbsDocumentBackend = _get_backend(test_doc_path)
+    page_backend: MetsGbsPageBackend = doc_backend.load_page(9)
+
+    page_backend.get_page_image(
+        scale=2, cropbox=BoundingBox(l=270, t=587, r=1385, b=1995)
+    )
+    # im.show()
+
+    # Explicitly clean up resources
+    page_backend.unload()
+    doc_backend.unload()
+
+
+def test_crop_page_image_jp2(test_doc_path):
+    doc_backend: MetsGbsDocumentBackend = _get_backend(test_doc_path)
+    page_backend: MetsGbsPageBackend = doc_backend.load_page(1)
+
+    page_backend.get_page_image(scale=2, cropbox=BoundingBox(l=160, t=29, r=732, b=173))
+    # im.show()
+
+    # Explicitly clean up resources
+    page_backend.unload()
+    doc_backend.unload()
+
+
+def test_num_pages(test_doc_path):
+    doc_backend: MetsGbsDocumentBackend = _get_backend(test_doc_path)
+    assert doc_backend.is_valid()
+    assert doc_backend.page_count() == 276
+
+    # Explicitly clean up resources to prevent race conditions in CI
+    doc_backend.unload()