feat: add options for choosing OCR engine

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2024-10-02 10:47:20 +02:00
parent cde671cf34
commit bbfc0617f2
4 changed files with 55 additions and 15 deletions

View File

@ -1,6 +1,7 @@
from enum import Enum, auto
from typing import List, Literal, Union
from pydantic import BaseModel
from pydantic import BaseModel, Field
class TableFormerMode(str, Enum):
@ -18,8 +19,24 @@ class TableStructureOptions(BaseModel):
mode: TableFormerMode = TableFormerMode.FAST
class OcrOptions(BaseModel):
kind: str
class EasyOcrOptions(OcrOptions):
kind: Literal["easyocr"] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"]
class TesseractOcrOptions(OcrOptions):
kind: Literal["tesseract"] = "tesseract"
class PipelineOptions(BaseModel):
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[EasyOcrOptions, TesseractOcrOptions] = Field(
EasyOcrOptions(), discriminator="kind"
)

View File

@ -3,21 +3,21 @@ import logging
from abc import abstractmethod
from typing import Iterable, List, Tuple
import numpy
import numpy as np
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import OcrOptions
_log = logging.getLogger(__name__)
class BaseOcrModel:
def __init__(self, config):
self.config = config
self.enabled = config["enabled"]
def __init__(self, enabled: bool, options: OcrOptions):
self.enabled = enabled
self.options = options
# Computes the optimum amount and coordinates of rectangles to OCR on a given page
def get_ocr_rects(self, page: Page) -> Tuple[bool, List[BoundingBox]]:

View File

@ -4,21 +4,23 @@ from typing import Iterable
import numpy
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import EasyOcrOptions
from docling.models.base_ocr_model import BaseOcrModel
_log = logging.getLogger(__name__)
class EasyOcrModel(BaseOcrModel):
def __init__(self, config):
super().__init__(config)
def __init__(self, enabled: bool, options: EasyOcrOptions):
super().__init__(enabled=enabled, options=options)
self.options: EasyOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
if self.enabled:
import easyocr
self.reader = easyocr.Reader(config["lang"])
self.reader = easyocr.Reader(lang_list=self.options.lang)
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:

View File

@ -1,6 +1,11 @@
from pathlib import Path
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
PipelineOptions,
TesseractOcrOptions,
)
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.table_structure_model import TableStructureModel
@ -14,19 +19,35 @@ class StandardModelPipeline(BaseModelPipeline):
def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
super().__init__(artifacts_path, pipeline_options)
ocr_model: BaseOcrModel
if isinstance(pipeline_options.ocr_options, EasyOcrOptions):
ocr_model = EasyOcrModel(
enabled=pipeline_options.do_ocr,
options=pipeline_options.ocr_options,
)
elif isinstance(pipeline_options.ocr_options, TesseractOcrOptions):
raise NotImplemented()
# TODO
# ocr_model = TesseractOcrModel(
# enabled=pipeline_options.do_ocr,
# options=pipeline_options.ocr_options,
# )
else:
raise RuntimeError(
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
)
self.model_pipe = [
EasyOcrModel(
config={
"lang": ["fr", "de", "es", "en"],
"enabled": pipeline_options.do_ocr,
}
),
# OCR
ocr_model,
# Layout
LayoutModel(
config={
"artifacts_path": artifacts_path
/ StandardModelPipeline._layout_model_path
}
),
# Table structure
TableStructureModel(
config={
"artifacts_path": artifacts_path