use options objects

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2024-10-11 12:58:59 +02:00
parent cc9bcc424d
commit c6e1471e02
7 changed files with 78 additions and 67 deletions

View File

@ -72,19 +72,4 @@ class PdfPipelineOptions(PipelineOptions):
Field(EasyOcrOptions(), discriminator="kind") Field(EasyOcrOptions(), discriminator="kind")
) )
# Deprecated switch, superseded by `images_scale`. It is kept so existing
# callers keep working; the string below is the deprecation message attached
# to the field.
keep_page_images: Annotated[
    bool,
    Field(
        deprecated="`keep_page_images` is deprecated, set the value of `images_scale` instead"
    ),
] = False  # False: page images are removed in the assemble step
images_scale: Optional[float] = None # if set, the scale for generated images images_scale: Optional[float] = None # if set, the scale for generated images
@model_validator(mode="after")
def set_page_images_from_deprecated(self) -> "PdfPipelineOptions":
    """Translate the deprecated ``keep_page_images`` flag into ``images_scale``.

    When ``keep_page_images`` is truthy and no explicit ``images_scale`` was
    given, ``images_scale`` is set to 1.0. An explicitly set ``images_scale``
    is never overwritten.

    Returns:
        The same options instance.
    """
    fallback_scale = 1.0
    # The deprecated field is read inside the suppression context so that
    # the read does not surface a DeprecationWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        if self.keep_page_images:
            if self.images_scale is None:
                self.images_scale = fallback_scale
    return self

View File

@ -2,6 +2,7 @@ import copy
import logging import logging
import random import random
import time import time
from pathlib import Path
from typing import Iterable, List from typing import Iterable, List
from docling_core.types.experimental import CoordOrigin from docling_core.types.experimental import CoordOrigin
@ -43,11 +44,8 @@ class LayoutModel(AbstractPageModel):
FIGURE_LABEL = DocItemLabel.PICTURE FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA FORMULA_LABEL = DocItemLabel.FORMULA
def __init__(self, config): def __init__(self, artifacts_path: Path):
self.config = config self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
self.layout_predictor = LayoutPredictor(
config["artifacts_path"]
) # TODO temporary
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height): def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
MIN_INTERSECTION = 0.2 MIN_INTERSECTION = 0.2

View File

@ -2,6 +2,8 @@ import logging
import re import re
from typing import Iterable, List from typing import Iterable, List
from pydantic import BaseModel
from docling.datamodel.base_models import ( from docling.datamodel.base_models import (
AssembledUnit, AssembledUnit,
FigureElement, FigureElement,
@ -16,9 +18,13 @@ from docling.models.layout_model import LayoutModel
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class PageAssembleOptions(BaseModel):
    """Options consumed by ``PageAssembleModel``."""

    # When False (the default), the assemble step clears the page's image
    # cache; set to True to keep the cached page images.
    keep_images: bool = False
class PageAssembleModel(AbstractPageModel): class PageAssembleModel(AbstractPageModel):
def __init__(self, config): def __init__(self, options: PageAssembleOptions):
self.config = config self.options = options
def sanitize_text(self, lines): def sanitize_text(self, lines):
if len(lines) <= 1: if len(lines) <= 1:
@ -147,7 +153,7 @@ class PageAssembleModel(AbstractPageModel):
) )
# Remove page images (can be disabled) # Remove page images (can be disabled)
if self.config["images_scale"] is None: if not self.options.keep_images:
page._image_cache = {} page._image_cache = {}
# Unload backend # Unload backend

View File

@ -1,14 +1,19 @@
from typing import Iterable from typing import Iterable, Optional
from PIL import ImageDraw from PIL import ImageDraw
from pydantic import BaseModel
from docling.datamodel.base_models import Page from docling.datamodel.base_models import Page
from docling.models.abstract_model import AbstractPageModel from docling.models.abstract_model import AbstractPageModel
class PagePreprocessingOptions(BaseModel):
    """Options consumed by ``PagePreprocessingModel``."""

    # Scale requested by the user for generated page images. The field is
    # required but nullable: the pipeline forwards its own
    # ``Optional[float]`` setting unchanged, and the model only applies the
    # scale when the value is not None. Declaring it as plain ``float`` would
    # reject the pipeline's default of None at validation time.
    images_scale: Optional[float]
class PagePreprocessingModel(AbstractPageModel): class PagePreprocessingModel(AbstractPageModel):
def __init__(self, config): def __init__(self, options: PagePreprocessingOptions):
self.config = config self.options = options
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch: for page in page_batch:
@ -23,7 +28,7 @@ class PagePreprocessingModel(AbstractPageModel):
scale=1.0 scale=1.0
) # puts the page image on the image cache at default scale ) # puts the page image on the image cache at default scale
images_scale = self.config["images_scale"] images_scale = self.options.images_scale
# user requested scales # user requested scales
if images_scale is not None: if images_scale is not None:
page._default_image_scale = images_scale page._default_image_scale = images_scale

View File

@ -10,19 +10,21 @@ from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredic
from PIL import ImageDraw from PIL import ImageDraw
from docling.datamodel.base_models import Page, Table, TableStructurePrediction from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.pipeline_options import TableFormerMode from docling.datamodel.pipeline_options import TableFormerMode, TableStructureOptions
from docling.models.abstract_model import AbstractPageModel from docling.models.abstract_model import AbstractPageModel
class TableStructureModel(AbstractPageModel): class TableStructureModel(AbstractPageModel):
def __init__(self, config): def __init__(
self.config = config self, enabled: bool, artifacts_path: Path, options: TableStructureOptions
self.do_cell_matching = config["do_cell_matching"] ):
self.mode = config["mode"] self.options = options
self.do_cell_matching = self.options.do_cell_matching
self.mode = self.options.mode
self.enabled = config["enabled"] self.enabled = enabled
if self.enabled: if self.enabled:
artifacts_path: Path = config["artifacts_path"] artifacts_path: Path = artifacts_path
if self.mode == TableFormerMode.ACCURATE: if self.mode == TableFormerMode.ACCURATE:
artifacts_path = artifacts_path / "fat" artifacts_path = artifacts_path / "fat"

View File

@ -16,8 +16,11 @@ from docling.models.base_ocr_model import BaseOcrModel
from docling.models.ds_glm_model import GlmModel from docling.models.ds_glm_model import GlmModel
from docling.models.easyocr_model import EasyOcrModel from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import PagePreprocessingModel from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.table_structure_model import TableStructureModel from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel from docling.models.tesseract_ocr_model import TesseractOcrModel
@ -32,6 +35,7 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
def __init__(self, pipeline_options: PdfPipelineOptions): def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options) super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions
if not pipeline_options.artifacts_path: if not pipeline_options.artifacts_path:
artifacts_path = self.download_models_hf() artifacts_path = self.download_models_hf()
@ -39,48 +43,38 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
self.artifacts_path = Path(artifacts_path) self.artifacts_path = Path(artifacts_path)
self.glm_model = GlmModel(config={}) self.glm_model = GlmModel(config={})
ocr_model: BaseOcrModel if (ocr_model := self.get_ocr_model()) is None:
if isinstance(pipeline_options.ocr_options, EasyOcrOptions):
ocr_model = EasyOcrModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
elif isinstance(pipeline_options.ocr_options, TesseractCliOcrOptions):
ocr_model = TesseractOcrCliModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
elif isinstance(pipeline_options.ocr_options, TesseractOcrOptions):
ocr_model = TesseractOcrModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
else:
raise RuntimeError( raise RuntimeError(
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}." f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
) )
self.model_pipe = [ self.model_pipe = [
# Pre-processing
PagePreprocessingModel( PagePreprocessingModel(
config={"images_scale": pipeline_options.images_scale} options=PagePreprocessingOptions(
images_scale=pipeline_options.images_scale
)
), ),
# OCR
ocr_model, ocr_model,
# Layout model
LayoutModel( LayoutModel(
config={ artifacts_path=artifacts_path
"artifacts_path": artifacts_path / StandardPdfModelPipeline._layout_model_path
/ StandardPdfModelPipeline._layout_model_path
}
), ),
# Table structure model
TableStructureModel( TableStructureModel(
config={ enabled=pipeline_options.do_table_structure,
"artifacts_path": artifacts_path artifacts_path=artifacts_path
/ StandardPdfModelPipeline._table_model_path, / StandardPdfModelPipeline._table_model_path,
"enabled": pipeline_options.do_table_structure, options=pipeline_options.table_structure_options,
"do_cell_matching": pipeline_options.table_structure_options.do_cell_matching, ),
"mode": pipeline_options.table_structure_options.mode, # Page assemble
} PageAssembleModel(
options=PageAssembleOptions(
keep_images=pipeline_options.images_scale is not None
)
), ),
PageAssembleModel(config={"images_scale": pipeline_options.images_scale}),
] ]
self.enrichment_pipe = [ self.enrichment_pipe = [
@ -102,6 +96,24 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
return Path(download_path) return Path(download_path)
def get_ocr_model(self) -> Optional[BaseOcrModel]:
    """Instantiate the OCR model that matches the configured OCR options.

    The concrete type of ``self.pipeline_options.ocr_options`` selects the
    model class; ``self.pipeline_options.do_ocr`` is forwarded as the
    model's ``enabled`` flag.

    Returns:
        The constructed OCR model, or ``None`` when the options object is
        not an instance of any supported options class.
    """
    ocr_options = self.pipeline_options.ocr_options

    # (options class, model class) pairs; the first isinstance match wins,
    # so the order is significant.
    dispatch = (
        (EasyOcrOptions, EasyOcrModel),
        (TesseractCliOcrOptions, TesseractOcrCliModel),
        (TesseractOcrOptions, TesseractOcrModel),
    )
    for options_type, model_type in dispatch:
        if isinstance(ocr_options, options_type):
            return model_type(
                enabled=self.pipeline_options.do_ocr,
                options=ocr_options,
            )
    return None
def initialize_page(self, doc: InputDocument, page: Page) -> Page: def initialize_page(self, doc: InputDocument, page: Page) -> Page:
page._backend = doc._backend.load_page(page.page_no) page._backend = doc._backend.load_page(page.page_no)
page.size = page._backend.get_size() page.size = page._backend.get_size()

View File

@ -1,4 +1,5 @@
import json import json
import warnings
from pathlib import Path from pathlib import Path
from typing import List from typing import List
@ -235,8 +236,10 @@ def verify_conversion_result_v1(
doc_pred_pages: List[Page] = doc_result.pages doc_pred_pages: List[Page] = doc_result.pages
doc_pred: DsDocument = doc_result.legacy_output doc_pred: DsDocument = doc_result.legacy_output
doc_pred_md = doc_result.render_as_markdown() with warnings.catch_warnings():
doc_pred_dt = doc_result.render_as_doctags() warnings.simplefilter("ignore", DeprecationWarning)
doc_pred_md = doc_result.render_as_markdown()
doc_pred_dt = doc_result.render_as_doctags()
engine_suffix = "" if ocr_engine is None else f".{ocr_engine}" engine_suffix = "" if ocr_engine is None else f".{ocr_engine}"
gt_subpath = input_path.parent / "groundtruth" / "docling_v1" / input_path.name gt_subpath = input_path.parent / "groundtruth" / "docling_v1" / input_path.name