From bbfc0617f276c058b75ba4e4e8e8b6a236cc13ec Mon Sep 17 00:00:00 2001
From: Michele Dolfi
Date: Wed, 2 Oct 2024 10:47:20 +0200
Subject: [PATCH] feat: add options for choosing OCR engine

Make the OCR engine of the standard pipeline configurable.

* pipeline_options: add an `OcrOptions` base model with the concrete
  `EasyOcrOptions` (carrying the `lang` list) and `TesseractOcrOptions`.
  `PipelineOptions.ocr_options` is a union of the two, discriminated on
  the `kind` field and defaulting to `EasyOcrOptions()`.
* base_ocr_model / easyocr_model: replace the untyped `config` dict with
  explicit `enabled` and `options` constructor arguments. The EasyOCR
  reader now takes its languages from `options.lang`.
* standard_model_pipeline: build the OCR model from the type of
  `pipeline_options.ocr_options`. Tesseract is declared but not wired up
  yet, so selecting it raises `NotImplementedError`; any other options
  type raises `RuntimeError`.

Signed-off-by: Michele Dolfi
---
 docling/datamodel/pipeline_options.py       | 19 ++++++++++-
 docling/models/base_ocr_model.py            |  8 ++---
 docling/models/easyocr_model.py             |  8 +++--
 docling/pipeline/standard_model_pipeline.py | 35 ++++++++++++++++-----
 4 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index 9ea7a77f..6742c412 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -1,6 +1,7 @@
 from enum import Enum, auto
+from typing import List, Literal, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 class TableFormerMode(str, Enum):
@@ -18,8 +19,24 @@ class TableStructureOptions(BaseModel):
     mode: TableFormerMode = TableFormerMode.FAST
 
 
+class OcrOptions(BaseModel):
+    kind: str
+
+
+class EasyOcrOptions(OcrOptions):
+    kind: Literal["easyocr"] = "easyocr"
+    lang: List[str] = ["fr", "de", "es", "en"]
+
+
+class TesseractOcrOptions(OcrOptions):
+    kind: Literal["tesseract"] = "tesseract"
+
+
 class PipelineOptions(BaseModel):
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
 
     table_structure_options: TableStructureOptions = TableStructureOptions()
+    ocr_options: Union[EasyOcrOptions, TesseractOcrOptions] = Field(
+        EasyOcrOptions(), discriminator="kind"
+    )
diff --git a/docling/models/base_ocr_model.py b/docling/models/base_ocr_model.py
index 3b3c261e..4139d689 100644
--- a/docling/models/base_ocr_model.py
+++ b/docling/models/base_ocr_model.py
@@ -3,21 +3,21 @@ import logging
 from abc import abstractmethod
 from typing import Iterable, List, Tuple
 
-import numpy
 import numpy as np
 from PIL import Image, ImageDraw
 from rtree import index
 from scipy.ndimage import find_objects, label
 
 from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
+from docling.datamodel.pipeline_options import OcrOptions
 
 _log = logging.getLogger(__name__)
 
 
 class BaseOcrModel:
-    def __init__(self, config):
-        self.config = config
-        self.enabled = config["enabled"]
+    def __init__(self, enabled: bool, options: OcrOptions):
+        self.enabled = enabled
+        self.options = options
 
     # Computes the optimum amount and coordinates of rectangles to OCR on a given page
     def get_ocr_rects(self, page: Page) -> Tuple[bool, List[BoundingBox]]:
diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py
index 5fb4066b..5fd36ca8 100644
--- a/docling/models/easyocr_model.py
+++ b/docling/models/easyocr_model.py
@@ -4,21 +4,23 @@ from typing import Iterable
 import numpy
 
 from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
+from docling.datamodel.pipeline_options import EasyOcrOptions
 from docling.models.base_ocr_model import BaseOcrModel
 
 _log = logging.getLogger(__name__)
 
 
 class EasyOcrModel(BaseOcrModel):
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, enabled: bool, options: EasyOcrOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: EasyOcrOptions
 
         self.scale = 3  # multiplier for 72 dpi == 216 dpi.
 
         if self.enabled:
             import easyocr
 
-            self.reader = easyocr.Reader(config["lang"])
+            self.reader = easyocr.Reader(lang_list=self.options.lang)
 
     def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
 
diff --git a/docling/pipeline/standard_model_pipeline.py b/docling/pipeline/standard_model_pipeline.py
index 3532fea6..4f3d0214 100644
--- a/docling/pipeline/standard_model_pipeline.py
+++ b/docling/pipeline/standard_model_pipeline.py
@@ -1,6 +1,11 @@
 from pathlib import Path
 
-from docling.datamodel.pipeline_options import PipelineOptions
+from docling.datamodel.pipeline_options import (
+    EasyOcrOptions,
+    PipelineOptions,
+    TesseractOcrOptions,
+)
+from docling.models.base_ocr_model import BaseOcrModel
 from docling.models.easyocr_model import EasyOcrModel
 from docling.models.layout_model import LayoutModel
 from docling.models.table_structure_model import TableStructureModel
@@ -14,19 +19,35 @@ class StandardModelPipeline(BaseModelPipeline):
     def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
         super().__init__(artifacts_path, pipeline_options)
 
+        ocr_model: BaseOcrModel
+        if isinstance(pipeline_options.ocr_options, EasyOcrOptions):
+            ocr_model = EasyOcrModel(
+                enabled=pipeline_options.do_ocr,
+                options=pipeline_options.ocr_options,
+            )
+        elif isinstance(pipeline_options.ocr_options, TesseractOcrOptions):
+            raise NotImplementedError()
+            # TODO
+            # ocr_model = TesseractOcrModel(
+            #     enabled=pipeline_options.do_ocr,
+            #     options=pipeline_options.ocr_options,
+            # )
+        else:
+            raise RuntimeError(
+                f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
+            )
+
         self.model_pipe = [
-            EasyOcrModel(
-                config={
-                    "lang": ["fr", "de", "es", "en"],
-                    "enabled": pipeline_options.do_ocr,
-                }
-            ),
+            # OCR
+            ocr_model,
+            # Layout
             LayoutModel(
                 config={
                     "artifacts_path": artifacts_path
                     / StandardModelPipeline._layout_model_path
                 }
             ),
+            # Table structure
             TableStructureModel(
                 config={
                     "artifacts_path": artifacts_path