mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
use options objects
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
cc9bcc424d
commit
c6e1471e02
@ -72,19 +72,4 @@ class PdfPipelineOptions(PipelineOptions):
|
||||
Field(EasyOcrOptions(), discriminator="kind")
|
||||
)
|
||||
|
||||
keep_page_images: Annotated[
|
||||
bool,
|
||||
Field(
|
||||
deprecated="`keep_page_images` is depreacted, set the value of `images_scale` instead"
|
||||
),
|
||||
] = False # False: page images are removed in the assemble step
|
||||
images_scale: Optional[float] = None # if set, the scale for generated images
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_page_images_from_deprecated(self) -> "PdfPipelineOptions":
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
default_scale = 1.0
|
||||
if self.keep_page_images and self.images_scale is None:
|
||||
self.images_scale = default_scale
|
||||
return self
|
||||
|
@ -2,6 +2,7 @@ import copy
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
from docling_core.types.experimental import CoordOrigin
|
||||
@ -43,11 +44,8 @@ class LayoutModel(AbstractPageModel):
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
config["artifacts_path"]
|
||||
) # TODO temporary
|
||||
def __init__(self, artifacts_path: Path):
|
||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||
|
||||
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
|
||||
MIN_INTERSECTION = 0.2
|
||||
|
@ -2,6 +2,8 @@ import logging
|
||||
import re
|
||||
from typing import Iterable, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
AssembledUnit,
|
||||
FigureElement,
|
||||
@ -16,9 +18,13 @@ from docling.models.layout_model import LayoutModel
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PageAssembleOptions(BaseModel):
|
||||
keep_images: bool = False
|
||||
|
||||
|
||||
class PageAssembleModel(AbstractPageModel):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
def __init__(self, options: PageAssembleOptions):
|
||||
self.options = options
|
||||
|
||||
def sanitize_text(self, lines):
|
||||
if len(lines) <= 1:
|
||||
@ -147,7 +153,7 @@ class PageAssembleModel(AbstractPageModel):
|
||||
)
|
||||
|
||||
# Remove page images (can be disabled)
|
||||
if self.config["images_scale"] is None:
|
||||
if not self.options.keep_images:
|
||||
page._image_cache = {}
|
||||
|
||||
# Unload backend
|
||||
|
@ -1,14 +1,19 @@
|
||||
from typing import Iterable
|
||||
|
||||
from PIL import ImageDraw
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.models.abstract_model import AbstractPageModel
|
||||
|
||||
|
||||
class PagePreprocessingOptions(BaseModel):
|
||||
images_scale: float
|
||||
|
||||
|
||||
class PagePreprocessingModel(AbstractPageModel):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
def __init__(self, options: PagePreprocessingOptions):
|
||||
self.options = options
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
@ -23,7 +28,7 @@ class PagePreprocessingModel(AbstractPageModel):
|
||||
scale=1.0
|
||||
) # puts the page image on the image cache at default scale
|
||||
|
||||
images_scale = self.config["images_scale"]
|
||||
images_scale = self.options.images_scale
|
||||
# user requested scales
|
||||
if images_scale is not None:
|
||||
page._default_image_scale = images_scale
|
||||
|
@ -10,19 +10,21 @@ from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredic
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
||||
from docling.datamodel.pipeline_options import TableFormerMode
|
||||
from docling.datamodel.pipeline_options import TableFormerMode, TableStructureOptions
|
||||
from docling.models.abstract_model import AbstractPageModel
|
||||
|
||||
|
||||
class TableStructureModel(AbstractPageModel):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.do_cell_matching = config["do_cell_matching"]
|
||||
self.mode = config["mode"]
|
||||
def __init__(
|
||||
self, enabled: bool, artifacts_path: Path, options: TableStructureOptions
|
||||
):
|
||||
self.options = options
|
||||
self.do_cell_matching = self.options.do_cell_matching
|
||||
self.mode = self.options.mode
|
||||
|
||||
self.enabled = config["enabled"]
|
||||
self.enabled = enabled
|
||||
if self.enabled:
|
||||
artifacts_path: Path = config["artifacts_path"]
|
||||
artifacts_path: Path = artifacts_path
|
||||
|
||||
if self.mode == TableFormerMode.ACCURATE:
|
||||
artifacts_path = artifacts_path / "fat"
|
||||
|
@ -16,8 +16,11 @@ from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.ds_glm_model import GlmModel
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel
|
||||
from docling.models.page_preprocessing_model import PagePreprocessingModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
|
||||
from docling.models.page_preprocessing_model import (
|
||||
PagePreprocessingModel,
|
||||
PagePreprocessingOptions,
|
||||
)
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||
@ -32,6 +35,7 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
|
||||
|
||||
def __init__(self, pipeline_options: PdfPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.pipeline_options: PdfPipelineOptions
|
||||
|
||||
if not pipeline_options.artifacts_path:
|
||||
artifacts_path = self.download_models_hf()
|
||||
@ -39,48 +43,38 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
|
||||
self.artifacts_path = Path(artifacts_path)
|
||||
self.glm_model = GlmModel(config={})
|
||||
|
||||
ocr_model: BaseOcrModel
|
||||
if isinstance(pipeline_options.ocr_options, EasyOcrOptions):
|
||||
ocr_model = EasyOcrModel(
|
||||
enabled=pipeline_options.do_ocr,
|
||||
options=pipeline_options.ocr_options,
|
||||
)
|
||||
elif isinstance(pipeline_options.ocr_options, TesseractCliOcrOptions):
|
||||
ocr_model = TesseractOcrCliModel(
|
||||
enabled=pipeline_options.do_ocr,
|
||||
options=pipeline_options.ocr_options,
|
||||
)
|
||||
elif isinstance(pipeline_options.ocr_options, TesseractOcrOptions):
|
||||
ocr_model = TesseractOcrModel(
|
||||
enabled=pipeline_options.do_ocr,
|
||||
options=pipeline_options.ocr_options,
|
||||
)
|
||||
else:
|
||||
if ocr_model := self.get_ocr_model() is None:
|
||||
raise RuntimeError(
|
||||
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
|
||||
)
|
||||
|
||||
self.model_pipe = [
|
||||
# Pre-processing
|
||||
PagePreprocessingModel(
|
||||
config={"images_scale": pipeline_options.images_scale}
|
||||
options=PagePreprocessingOptions(
|
||||
images_scale=pipeline_options.images_scale
|
||||
)
|
||||
),
|
||||
# OCR
|
||||
ocr_model,
|
||||
# Layout model
|
||||
LayoutModel(
|
||||
config={
|
||||
"artifacts_path": artifacts_path
|
||||
artifacts_path=artifacts_path
|
||||
/ StandardPdfModelPipeline._layout_model_path
|
||||
}
|
||||
),
|
||||
# Table structure model
|
||||
TableStructureModel(
|
||||
config={
|
||||
"artifacts_path": artifacts_path
|
||||
enabled=pipeline_options.do_table_structure,
|
||||
artifacts_path=artifacts_path
|
||||
/ StandardPdfModelPipeline._table_model_path,
|
||||
"enabled": pipeline_options.do_table_structure,
|
||||
"do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
|
||||
"mode": pipeline_options.table_structure_options.mode,
|
||||
}
|
||||
options=pipeline_options.table_structure_options,
|
||||
),
|
||||
# Page assemble
|
||||
PageAssembleModel(
|
||||
options=PageAssembleOptions(
|
||||
keep_images=pipeline_options.images_scale is not None
|
||||
)
|
||||
),
|
||||
PageAssembleModel(config={"images_scale": pipeline_options.images_scale}),
|
||||
]
|
||||
|
||||
self.enrichment_pipe = [
|
||||
@ -102,6 +96,24 @@ class StandardPdfModelPipeline(PaginatedModelPipeline):
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def get_ocr_model(self) -> Optional[BaseOcrModel]:
|
||||
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
|
||||
return EasyOcrModel(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
options=self.pipeline_options.ocr_options,
|
||||
)
|
||||
elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions):
|
||||
return TesseractOcrCliModel(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
options=self.pipeline_options.ocr_options,
|
||||
)
|
||||
elif isinstance(self.pipeline_options.ocr_options, TesseractOcrOptions):
|
||||
return TesseractOcrModel(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
options=self.pipeline_options.ocr_options,
|
||||
)
|
||||
return None
|
||||
|
||||
def initialize_page(self, doc: InputDocument, page: Page) -> Page:
|
||||
page._backend = doc._backend.load_page(page.page_no)
|
||||
page.size = page._backend.get_size()
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
@ -235,6 +236,8 @@ def verify_conversion_result_v1(
|
||||
|
||||
doc_pred_pages: List[Page] = doc_result.pages
|
||||
doc_pred: DsDocument = doc_result.legacy_output
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
doc_pred_md = doc_result.render_as_markdown()
|
||||
doc_pred_dt = doc_result.render_as_doctags()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user