diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index dc00a5cf..052b5621 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -59,9 +59,14 @@ class TableFormerMode(str, Enum): ACCURATE = "accurate" -class TableStructureOptions(BaseModel): +class BaseTableStructureOptions(BaseOptions): + """Base options for table structure models.""" + + +class TableStructureOptions(BaseTableStructureOptions): """Options for the table structure.""" + kind: ClassVar[str] = "docling_tableformer" do_cell_matching: bool = ( True # True: Matches predictions back to PDF cells. Can break table output if PDF cells @@ -308,19 +313,25 @@ class VlmPipelineOptions(PaginatedPipelineOptions): ) -class LayoutOptions(BaseModel): - """Options for layout processing.""" +class BaseLayoutOptions(BaseOptions): + """Base options for layout models.""" - create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells keep_empty_clusters: bool = ( False # Whether to keep clusters that contain no text cells ) - model_spec: LayoutModelConfig = DOCLING_LAYOUT_HERON skip_cell_assignment: bool = ( False # Skip cell-to-cluster assignment for VLM-only processing ) +class LayoutOptions(BaseLayoutOptions): + """Options for layout processing.""" + + kind: ClassVar[str] = "docling_layout_default" + create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells + model_spec: LayoutModelConfig = DOCLING_LAYOUT_HERON + + class AsrPipelineOptions(PipelineOptions): asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY diff --git a/docling/models/base_layout_model.py b/docling/models/base_layout_model.py new file mode 100644 index 00000000..5e430474 --- /dev/null +++ b/docling/models/base_layout_model.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence +from typing import Type + +from docling.datamodel.base_models import LayoutPrediction, Page +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options import BaseLayoutOptions +from docling.models.base_model import BaseModelWithOptions, BasePageModel + + +class BaseLayoutModel(BasePageModel, BaseModelWithOptions, ABC): + """Shared interface for layout models.""" + + @classmethod + @abstractmethod + def get_options_type(cls) -> Type[BaseLayoutOptions]: + """Return the options type supported by this layout model.""" + + @abstractmethod + def predict_layout( + self, + conv_res: ConversionResult, + pages: Sequence[Page], + ) -> Sequence[LayoutPrediction]: + """Produce layout predictions for the provided pages.""" + + def __call__( + self, + conv_res: ConversionResult, + page_batch: Iterable[Page], + ) -> Iterable[Page]: + pages = list(page_batch) + predictions = self.predict_layout(conv_res, pages) + + for page, prediction in zip(pages, predictions): + page.predictions.layout = prediction + yield page diff --git a/docling/models/base_table_model.py b/docling/models/base_table_model.py new file mode 100644 index 00000000..9c379fc9 --- /dev/null +++ b/docling/models/base_table_model.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence +from typing import Type + +from docling.datamodel.base_models import Page, TableStructurePrediction +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options import BaseTableStructureOptions +from docling.models.base_model import BaseModelWithOptions, BasePageModel + + +class BaseTableStructureModel(BasePageModel, BaseModelWithOptions, ABC): + """Shared interface for table structure models.""" + + enabled: bool + + @classmethod + @abstractmethod + def get_options_type(cls) -> Type[BaseTableStructureOptions]: + """Return the options type supported by this table model.""" + + @abstractmethod + def predict_tables( + self, + conv_res: ConversionResult, + pages: Sequence[Page], + ) -> Sequence[TableStructurePrediction]: + """Produce table structure predictions for the provided pages.""" + + def __call__( + self, + conv_res: ConversionResult, + page_batch: Iterable[Page], + ) -> Iterable[Page]: + if not getattr(self, "enabled", True): + yield from page_batch + return + + pages = list(page_batch) + predictions = self.predict_tables(conv_res, pages) + + for page, prediction in zip(pages, predictions): + page.predictions.tablestructure = prediction + yield page diff --git a/docling/models/factories/__init__.py b/docling/models/factories/__init__.py index a6adb3f2..4cab9f3e 100644 --- a/docling/models/factories/__init__.py +++ b/docling/models/factories/__init__.py @@ -1,10 +1,12 @@ import logging from functools import lru_cache +from docling.models.factories.layout_factory import LayoutFactory from docling.models.factories.ocr_factory import OcrFactory from docling.models.factories.picture_description_factory import ( PictureDescriptionFactory, ) +from docling.models.factories.table_factory import TableStructureFactory logger = logging.getLogger(__name__) @@ -25,3 +27,21 @@ def get_picture_description_factory( factory.load_from_plugins(allow_external_plugins=allow_external_plugins) logger.info("Registered picture descriptions: %r", factory.registered_kind) return factory + + +@lru_cache +def get_layout_factory(allow_external_plugins: bool = False) -> LayoutFactory: + factory = LayoutFactory() + factory.load_from_plugins(allow_external_plugins=allow_external_plugins) + logger.info("Registered layout engines: %r", factory.registered_kind) + return factory + + +@lru_cache +def get_table_structure_factory( + allow_external_plugins: bool = False, +) -> TableStructureFactory: + factory = TableStructureFactory() + factory.load_from_plugins(allow_external_plugins=allow_external_plugins) + logger.info("Registered table structure engines: %r", factory.registered_kind) + return factory diff --git a/docling/models/factories/layout_factory.py b/docling/models/factories/layout_factory.py new file mode 100644 index 00000000..7390c077 --- /dev/null +++ b/docling/models/factories/layout_factory.py @@ -0,0 +1,7 @@ +from docling.models.base_layout_model import BaseLayoutModel +from docling.models.factories.base_factory import BaseFactory + + +class LayoutFactory(BaseFactory[BaseLayoutModel]): + def __init__(self, *args, **kwargs): + super().__init__("layout_engines", *args, **kwargs) diff --git a/docling/models/factories/table_factory.py b/docling/models/factories/table_factory.py new file mode 100644 index 00000000..ccb2a07e --- /dev/null +++ b/docling/models/factories/table_factory.py @@ -0,0 +1,7 @@ +from docling.models.base_table_model import BaseTableStructureModel +from docling.models.factories.base_factory import BaseFactory + + +class TableStructureFactory(BaseFactory[BaseTableStructureModel]): + def __init__(self, *args, **kwargs): + super().__init__("table_structure_engines", *args, **kwargs) diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 06001cd9..e3e08d0f 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -1,7 +1,7 @@ import copy import logging import warnings -from collections.abc import Iterable +from collections.abc import Sequence from pathlib import Path from typing import List, Optional, Union @@ -15,7 +15,7 @@ from docling.datamodel.document import ConversionResult from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_V2, LayoutModelConfig from docling.datamodel.pipeline_options import LayoutOptions from docling.datamodel.settings import settings -from docling.models.base_model import BasePageModel +from docling.models.base_layout_model import BaseLayoutModel from docling.models.utils.hf_model_download import download_hf_model from docling.utils.accelerator_utils import decide_device from docling.utils.layout_postprocessor import LayoutPostprocessor @@ -25,7 +25,7 @@ from docling.utils.visualization import draw_clusters _log = logging.getLogger(__name__) -class LayoutModel(BasePageModel): +class LayoutModel(BaseLayoutModel): TEXT_ELEM_LABELS = [ DocItemLabel.TEXT, DocItemLabel.FOOTNOTE, @@ -86,6 +86,10 @@ class LayoutModel(BasePageModel): num_threads=accelerator_options.num_threads, ) + @classmethod + def get_options_type(cls) -> type[LayoutOptions]: + return LayoutOptions + @staticmethod def download_models( local_dir: Optional[Path] = None, @@ -145,11 +149,13 @@ class LayoutModel(BasePageModel): out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png" combined_image.save(str(out_file), format="png") - def __call__( - self, conv_res: ConversionResult, page_batch: Iterable[Page] - ) -> Iterable[Page]: - # Convert to list to allow multiple iterations - pages = list(page_batch) + def predict_layout( + self, + conv_res: ConversionResult, + pages: Sequence[Page], + ) -> Sequence[LayoutPrediction]: + # Convert to list to ensure predictable iteration + pages = list(pages) # Separate valid and invalid pages valid_pages = [] @@ -167,12 +173,6 @@ class LayoutModel(BasePageModel): valid_pages.append(page) valid_page_images.append(page_image) - _log.debug(f"{len(pages)=}") - if pages: - _log.debug(f"{pages[0].page_no}-{pages[-1].page_no}") - _log.debug(f"{len(valid_pages)=}") - _log.debug(f"{len(valid_page_images)=}") - # Process all valid pages with batch prediction batch_predictions = [] if valid_page_images: @@ -182,11 +182,14 @@ class LayoutModel(BasePageModel): ) # Process each page with its predictions + layout_predictions: list[LayoutPrediction] = [] valid_page_idx = 0 for page in pages: assert page._backend is not None if not page._backend.is_valid(): - yield page + existing_prediction = page.predictions.layout or LayoutPrediction() + page.predictions.layout = existing_prediction + layout_predictions.append(existing_prediction) continue page_predictions = batch_predictions[valid_page_idx] @@ -233,11 +236,14 @@ class LayoutModel(BasePageModel): np.mean([c.confidence for c in processed_cells if c.from_ocr]) ) - page.predictions.layout = LayoutPrediction(clusters=processed_clusters) + prediction = LayoutPrediction(clusters=processed_clusters) + page.predictions.layout = prediction if settings.debug.visualize_layout: self.draw_clusters_and_cells_side_by_side( conv_res, page, processed_clusters, mode_prefix="postprocessed" ) - yield page + layout_predictions.append(prediction) + + return layout_predictions diff --git a/docling/models/plugins/defaults.py b/docling/models/plugins/defaults.py index 352f5c75..06b87080 100644 --- a/docling/models/plugins/defaults.py +++ b/docling/models/plugins/defaults.py @@ -28,3 +28,23 @@ def picture_description(): PictureDescriptionApiModel, ] } + + +def layout_engines(): + from docling.models.layout_model import LayoutModel + + return { + "layout_engines": [ + LayoutModel, + ] + } + + +def table_structure_engines(): + from docling.models.table_structure_model import TableStructureModel + + return { + "table_structure_engines": [ + TableStructureModel, + ] + } diff --git a/docling/models/table_structure_model.py b/docling/models/table_structure_model.py index c3365f79..c71f945f 100644 --- a/docling/models/table_structure_model.py +++ b/docling/models/table_structure_model.py @@ -1,6 +1,6 @@ import copy import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from pathlib import Path from typing import Optional @@ -20,13 +20,13 @@ from docling.datamodel.pipeline_options import ( TableStructureOptions, ) from docling.datamodel.settings import settings -from docling.models.base_model import BasePageModel +from docling.models.base_table_model import BaseTableStructureModel from docling.models.utils.hf_model_download import download_hf_model from docling.utils.accelerator_utils import decide_device from docling.utils.profiling import TimeRecorder -class TableStructureModel(BasePageModel): +class TableStructureModel(BaseTableStructureModel): _model_repo_folder = "docling-project--docling-models" _model_path = "model_artifacts/tableformer" @@ -88,6 +88,10 @@ class TableStructureModel(BasePageModel): ) self.scale = 2.0 # Scale up table input images to 144 dpi + @classmethod + def get_options_type(cls) -> type[TableStructureOptions]: + return TableStructureOptions + @staticmethod def download_models( local_dir: Optional[Path] = None, force: bool = False, progress: bool = False @@ -167,138 +171,135 @@ class TableStructureModel(BasePageModel): out_file = out_path / f"table_struct_page_{page.page_no:05}.png" image.save(str(out_file), format="png") - def __call__( - self, conv_res: ConversionResult, page_batch: Iterable[Page] - ) -> Iterable[Page]: - if not self.enabled: - yield from page_batch - return + def predict_tables( + self, + conv_res: ConversionResult, + pages: Sequence[Page], + ) -> Sequence[TableStructurePrediction]: + pages = list(pages) + predictions: list[TableStructurePrediction] = [] - for page in page_batch: + for page in pages: assert page._backend is not None if not page._backend.is_valid(): - yield page - else: - with TimeRecorder(conv_res, "table_structure"): - assert page.predictions.layout is not None - assert page.size is not None + existing_prediction = ( + page.predictions.tablestructure or TableStructurePrediction() + ) + page.predictions.tablestructure = existing_prediction + predictions.append(existing_prediction) + continue - page.predictions.tablestructure = ( - TableStructurePrediction() - ) # dummy + with TimeRecorder(conv_res, "table_structure"): + assert page.predictions.layout is not None + assert page.size is not None - in_tables = [ - ( - cluster, - [ - round(cluster.bbox.l) * self.scale, - round(cluster.bbox.t) * self.scale, - round(cluster.bbox.r) * self.scale, - round(cluster.bbox.b) * self.scale, - ], + table_prediction = TableStructurePrediction() + page.predictions.tablestructure = table_prediction + + in_tables = [ + ( + cluster, + [ + round(cluster.bbox.l) * self.scale, + round(cluster.bbox.t) * self.scale, + round(cluster.bbox.r) * self.scale, + round(cluster.bbox.b) * self.scale, + ], + ) + for cluster in page.predictions.layout.clusters + if cluster.label + in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX] + ] + if not in_tables: + predictions.append(table_prediction) + continue + + page_input = { + "width": page.size.width * self.scale, + "height": page.size.height * self.scale, + "image": numpy.asarray(page.get_image(scale=self.scale)), + } + + for table_cluster, tbl_box in in_tables: + # Check if word-level cells are available from backend: + sp = page._backend.get_segmented_page() + if sp is not None: + tcells = sp.get_cells_in_bbox( + cell_unit=TextCellUnit.WORD, + bbox=table_cluster.bbox, ) - for cluster in page.predictions.layout.clusters - if cluster.label - in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX] - ] - if not len(in_tables): - yield page - continue - - page_input = { - "width": page.size.width * self.scale, - "height": page.size.height * self.scale, - "image": numpy.asarray(page.get_image(scale=self.scale)), - } - - table_clusters, table_bboxes = zip(*in_tables) - - if len(table_bboxes): - for table_cluster, tbl_box in in_tables: - # Check if word-level cells are available from backend: - sp = page._backend.get_segmented_page() - if sp is not None: - tcells = sp.get_cells_in_bbox( - cell_unit=TextCellUnit.WORD, - bbox=table_cluster.bbox, - ) - if len(tcells) == 0: - # In case word-level cells yield empty - tcells = table_cluster.cells - else: - # Otherwise - we use normal (line/phrase) cells - tcells = table_cluster.cells - tokens = [] - for c in tcells: - # Only allow non empty strings (spaces) into the cells of a table - if len(c.text.strip()) > 0: - new_cell = copy.deepcopy(c) - new_cell.rect = BoundingRectangle.from_bounding_box( - new_cell.rect.to_bounding_box().scaled( - scale=self.scale - ) - ) - tokens.append( - { - "id": new_cell.index, - "text": new_cell.text, - "bbox": new_cell.rect.to_bounding_box().model_dump(), - } - ) - page_input["tokens"] = tokens - - tf_output = self.tf_predictor.multi_table_predict( - page_input, [tbl_box], do_matching=self.do_cell_matching + if len(tcells) == 0: + # In case word-level cells yield empty + tcells = table_cluster.cells + else: + # Otherwise - we use normal (line/phrase) cells + tcells = table_cluster.cells + tokens = [] + for c in tcells: + # Only allow non empty strings (spaces) into the cells of a table + if len(c.text.strip()) > 0: + new_cell = copy.deepcopy(c) + new_cell.rect = BoundingRectangle.from_bounding_box( + new_cell.rect.to_bounding_box().scaled(scale=self.scale) ) - table_out = tf_output[0] - table_cells = [] - for element in table_out["tf_responses"]: - if not self.do_cell_matching: - the_bbox = BoundingBox.model_validate( - element["bbox"] - ).scaled(1 / self.scale) - text_piece = page._backend.get_text_in_rect( - the_bbox - ) - element["bbox"]["token"] = text_piece - - tc = TableCell.model_validate(element) - if tc.bbox is not None: - tc.bbox = tc.bbox.scaled(1 / self.scale) - table_cells.append(tc) - - assert "predict_details" in table_out - - # Retrieving cols/rows, after post processing: - num_rows = table_out["predict_details"].get("num_rows", 0) - num_cols = table_out["predict_details"].get("num_cols", 0) - otsl_seq = ( - table_out["predict_details"] - .get("prediction", {}) - .get("rs_seq", []) + tokens.append( + { + "id": new_cell.index, + "text": new_cell.text, + "bbox": new_cell.rect.to_bounding_box().model_dump(), + } ) + page_input["tokens"] = tokens - tbl = Table( - otsl_seq=otsl_seq, - table_cells=table_cells, - num_rows=num_rows, - num_cols=num_cols, - id=table_cluster.id, - page_no=page.page_no, - cluster=table_cluster, - label=table_cluster.label, - ) + tf_output = self.tf_predictor.multi_table_predict( + page_input, [tbl_box], do_matching=self.do_cell_matching + ) + table_out = tf_output[0] + table_cells = [] + for element in table_out["tf_responses"]: + if not self.do_cell_matching: + the_bbox = BoundingBox.model_validate( + element["bbox"] + ).scaled(1 / self.scale) + text_piece = page._backend.get_text_in_rect(the_bbox) + element["bbox"]["token"] = text_piece - page.predictions.tablestructure.table_map[ - table_cluster.id - ] = tbl + tc = TableCell.model_validate(element) + if tc.bbox is not None: + tc.bbox = tc.bbox.scaled(1 / self.scale) + table_cells.append(tc) - # For debugging purposes: - if settings.debug.visualize_tables: - self.draw_table_and_cells( - conv_res, - page, - page.predictions.tablestructure.table_map.values(), - ) + assert "predict_details" in table_out - yield page + # Retrieving cols/rows, after post processing: + num_rows = table_out["predict_details"].get("num_rows", 0) + num_cols = table_out["predict_details"].get("num_cols", 0) + otsl_seq = ( + table_out["predict_details"] + .get("prediction", {}) + .get("rs_seq", []) + ) + + tbl = Table( + otsl_seq=otsl_seq, + table_cells=table_cells, + num_rows=num_rows, + num_cols=num_cols, + id=table_cluster.id, + page_no=page.page_no, + cluster=table_cluster, + label=table_cluster.label, + ) + + table_prediction.table_map[table_cluster.id] = tbl + + if settings.debug.visualize_tables: + self.draw_table_and_cells( + conv_res, + page, + page.predictions.tablestructure.table_map.values(), + ) + + predictions.append(table_prediction) + + return predictions diff --git a/docling/pipeline/legacy_standard_pdf_pipeline.py b/docling/pipeline/legacy_standard_pdf_pipeline.py index edcf04f4..55c2703c 100644 --- a/docling/pipeline/legacy_standard_pdf_pipeline.py +++ b/docling/pipeline/legacy_standard_pdf_pipeline.py @@ -15,15 +15,17 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions -from docling.models.factories import get_ocr_factory -from docling.models.layout_model import LayoutModel +from docling.models.factories import ( + get_layout_factory, + get_ocr_factory, + get_table_structure_factory, +) from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_preprocessing_model import ( PagePreprocessingModel, PagePreprocessingOptions, ) from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions -from docling.models.table_structure_model import TableStructureModel from docling.pipeline.base_pipeline import PaginatedPipeline from docling.utils.model_downloader import download_models from docling.utils.profiling import ProfilingScope, TimeRecorder @@ -48,6 +50,24 @@ class LegacyStandardPdfPipeline(PaginatedPipeline): ocr_model = self.get_ocr_model(artifacts_path=self.artifacts_path) + layout_factory = get_layout_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + layout_model = layout_factory.create_instance( + options=pipeline_options.layout_options, + artifacts_path=self.artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + ) + table_factory = get_table_structure_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + table_model = table_factory.create_instance( + options=pipeline_options.table_structure_options, + enabled=pipeline_options.do_table_structure, + artifacts_path=self.artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + ) + self.build_pipe = [ # Pre-processing PagePreprocessingModel( @@ -58,18 +78,9 @@ class LegacyStandardPdfPipeline(PaginatedPipeline): # OCR ocr_model, # Layout model - LayoutModel( - artifacts_path=self.artifacts_path, - accelerator_options=pipeline_options.accelerator_options, - options=pipeline_options.layout_options, - ), + layout_model, # Table structure model - TableStructureModel( - enabled=pipeline_options.do_table_structure, - artifacts_path=self.artifacts_path, - options=pipeline_options.table_structure_options, - accelerator_options=pipeline_options.accelerator_options, - ), + table_model, # Page assemble PageAssembleModel(options=PageAssembleOptions()), ] diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 2eaaadca..585c548c 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -41,15 +41,17 @@ from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions from docling.datamodel.settings import settings from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions -from docling.models.factories import get_ocr_factory -from docling.models.layout_model import LayoutModel +from docling.models.factories import ( + get_layout_factory, + get_ocr_factory, + get_table_structure_factory, +) from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_preprocessing_model import ( PagePreprocessingModel, PagePreprocessingOptions, ) from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions -from docling.models.table_structure_model import TableStructureModel from docling.pipeline.base_pipeline import ConvertPipeline from docling.utils.profiling import ProfilingScope, TimeRecorder from docling.utils.utils import chunkify @@ -436,15 +438,21 @@ class StandardPdfPipeline(ConvertPipeline): ) ) self.ocr_model = self._make_ocr_model(art_path) - self.layout_model = LayoutModel( + layout_factory = get_layout_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + self.layout_model = layout_factory.create_instance( + options=self.pipeline_options.layout_options, artifacts_path=art_path, accelerator_options=self.pipeline_options.accelerator_options, - options=self.pipeline_options.layout_options, ) - self.table_model = TableStructureModel( + table_factory = get_table_structure_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + self.table_model = table_factory.create_instance( + options=self.pipeline_options.table_structure_options, enabled=self.pipeline_options.do_table_structure, artifacts_path=art_path, - options=self.pipeline_options.table_structure_options, accelerator_options=self.pipeline_options.accelerator_options, ) self.assemble_model = PageAssembleModel(options=PageAssembleOptions())