use options objects

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2024-10-11 12:58:59 +02:00
parent cc9bcc424d
commit c6e1471e02
7 changed files with 78 additions and 67 deletions

View File

@ -72,19 +72,4 @@ class PdfPipelineOptions(PipelineOptions):
Field(EasyOcrOptions(), discriminator="kind")
)
keep_page_images: Annotated[
bool,
Field(
deprecated="`keep_page_images` is depreacted, set the value of `images_scale` instead"
),
] = False # False: page images are removed in the assemble step
images_scale: Optional[float] = None # if set, the scale for generated images
@model_validator(mode="after")
def set_page_images_from_deprecated(self) -> "PdfPipelineOptions":
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
default_scale = 1.0
if self.keep_page_images and self.images_scale is None:
self.images_scale = default_scale
return self

View File

@ -2,6 +2,7 @@ import copy
import logging
import random
import time
from pathlib import Path
from typing import Iterable, List
from docling_core.types.experimental import CoordOrigin
@ -43,11 +44,8 @@ class LayoutModel(AbstractPageModel):
FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA
def __init__(self, config):
self.config = config
self.layout_predictor = LayoutPredictor(
config["artifacts_path"]
) # TODO temporary
def __init__(self, artifacts_path: Path):
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
MIN_INTERSECTION = 0.2

View File

@ -2,6 +2,8 @@ import logging
import re
from typing import Iterable, List
from pydantic import BaseModel
from docling.datamodel.base_models import (
AssembledUnit,
FigureElement,
@ -16,9 +18,13 @@ from docling.models.layout_model import LayoutModel
_log = logging.getLogger(__name__)
class PageAssembleOptions(BaseModel):
keep_images: bool = False
class PageAssembleModel(AbstractPageModel):
def __init__(self, config):
self.config = config
def __init__(self, options: PageAssembleOptions):
self.options = options
def sanitize_text(self, lines):
if len(lines) <= 1:
@ -147,7 +153,7 @@ class PageAssembleModel(AbstractPageModel):
)
# Remove page images (can be disabled)
if self.config["images_scale"] is None:
if not self.options.keep_images:
page._image_cache = {}
# Unload backend

View File

@ -1,14 +1,19 @@
from typing import Iterable
from PIL import ImageDraw
from pydantic import BaseModel
from docling.datamodel.base_models import Page
from docling.models.abstract_model import AbstractPageModel
class PagePreprocessingOptions(BaseModel):
images_scale: float
class PagePreprocessingModel(AbstractPageModel):
def __init__(self, config):
self.config = config
def __init__(self, options: PagePreprocessingOptions):
self.options = options
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
@ -23,7 +28,7 @@ class PagePreprocessingModel(AbstractPageModel):
scale=1.0
) # puts the page image on the image cache at default scale
images_scale = self.config["images_scale"]
images_scale = self.options.images_scale
# user requested scales
if images_scale is not None:
page._default_image_scale = images_scale

View File

@ -10,19 +10,21 @@ from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredic
from PIL import ImageDraw
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.pipeline_options import TableFormerMode
from docling.datamodel.pipeline_options import TableFormerMode, TableStructureOptions
from docling.models.abstract_model import AbstractPageModel
class TableStructureModel(AbstractPageModel):
def __init__(self, config):
self.config = config
self.do_cell_matching = config["do_cell_matching"]
self.mode = config["mode"]
def __init__(
self, enabled: bool, artifacts_path: Path, options: TableStructureOptions
):
self.options = options
self.do_cell_matching = self.options.do_cell_matching
self.mode = self.options.mode
self.enabled = config["enabled"]
self.enabled = enabled
if self.enabled:
artifacts_path: Path = config["artifacts_path"]
artifacts_path: Path = artifacts_path
if self.mode == TableFormerMode.ACCURATE:
artifacts_path = artifacts_path / "fat"

View File

@ -16,8 +16,11 @@ from docling.models.base_ocr_model import BaseOcrModel
from docling.models.ds_glm_model import GlmModel
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel
from docling.models.page_preprocessing_model import PagePreprocessingModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
@ -32,6 +35,7 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions
if not pipeline_options.artifacts_path:
artifacts_path = self.download_models_hf()
@ -39,48 +43,38 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
self.artifacts_path = Path(artifacts_path)
self.glm_model = GlmModel(config={})
ocr_model: BaseOcrModel
if isinstance(pipeline_options.ocr_options, EasyOcrOptions):
ocr_model = EasyOcrModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
elif isinstance(pipeline_options.ocr_options, TesseractCliOcrOptions):
ocr_model = TesseractOcrCliModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
elif isinstance(pipeline_options.ocr_options, TesseractOcrOptions):
ocr_model = TesseractOcrModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
else:
if ocr_model := self.get_ocr_model() is None:
raise RuntimeError(
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
)
self.model_pipe = [
# Pre-processing
PagePreprocessingModel(
config={"images_scale": pipeline_options.images_scale}
options=PagePreprocessingOptions(
images_scale=pipeline_options.images_scale
)
),
# OCR
ocr_model,
# Layout model
LayoutModel(
config={
"artifacts_path": artifacts_path
artifacts_path=artifacts_path
/ StandardPdfModelPipeline._layout_model_path
}
),
# Table structure model
TableStructureModel(
config={
"artifacts_path": artifacts_path
enabled=pipeline_options.do_table_structure,
artifacts_path=artifacts_path
/ StandardPdfModelPipeline._table_model_path,
"enabled": pipeline_options.do_table_structure,
"do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
"mode": pipeline_options.table_structure_options.mode,
}
options=pipeline_options.table_structure_options,
),
# Page assemble
PageAssembleModel(
options=PageAssembleOptions(
keep_images=pipeline_options.images_scale is not None
)
),
PageAssembleModel(config={"images_scale": pipeline_options.images_scale}),
]
self.enrichment_pipe = [
@ -102,6 +96,24 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
return Path(download_path)
def get_ocr_model(self) -> Optional[BaseOcrModel]:
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
return EasyOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions):
return TesseractOcrCliModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
elif isinstance(self.pipeline_options.ocr_options, TesseractOcrOptions):
return TesseractOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
return None
def initialize_page(self, doc: InputDocument, page: Page) -> Page:
page._backend = doc._backend.load_page(page.page_no)
page.size = page._backend.get_size()

View File

@ -1,4 +1,5 @@
import json
import warnings
from pathlib import Path
from typing import List
@ -235,6 +236,8 @@ def verify_conversion_result_v1(
doc_pred_pages: List[Page] = doc_result.pages
doc_pred: DsDocument = doc_result.legacy_output
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
doc_pred_md = doc_result.render_as_markdown()
doc_pred_dt = doc_result.render_as_doctags()