mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-11 22:28:31 +00:00
feat: AutoOCR model selecting the best OCR model available and deprecating the usage of EasyOCR (#2391)
* add auto ocr model Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * Apply suggestions from code review Co-authored-by: Christoph Auer <60343111+cau-git@users.noreply.github.com> Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> * add final log warning Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * propagate default options Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * allow rapidocr models download Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove modelscope Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Co-authored-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
This commit is contained in:
@@ -49,7 +49,7 @@ from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AsrPipelineOptions,
|
||||
ConvertPipelineOptions,
|
||||
EasyOcrOptions,
|
||||
OcrAutoOptions,
|
||||
OcrOptions,
|
||||
PaginatedPipelineOptions,
|
||||
PdfBackend,
|
||||
@@ -374,7 +374,7 @@ def convert( # noqa: C901
|
||||
f"Use the option --show-external-plugins to see the options allowed with external plugins."
|
||||
),
|
||||
),
|
||||
] = EasyOcrOptions.kind,
|
||||
] = OcrAutoOptions.kind,
|
||||
ocr_lang: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(
|
||||
|
||||
@@ -38,6 +38,7 @@ class _AvailableModels(str, Enum):
|
||||
SMOLDOCLING = "smoldocling"
|
||||
SMOLDOCLING_MLX = "smoldocling_mlx"
|
||||
GRANITE_VISION = "granite_vision"
|
||||
RAPIDOCR = "rapidocr"
|
||||
EASYOCR = "easyocr"
|
||||
|
||||
|
||||
@@ -46,7 +47,7 @@ _default_models = [
|
||||
_AvailableModels.TABLEFORMER,
|
||||
_AvailableModels.CODE_FORMULA,
|
||||
_AvailableModels.PICTURE_CLASSIFIER,
|
||||
_AvailableModels.EASYOCR,
|
||||
_AvailableModels.RAPIDOCR,
|
||||
]
|
||||
|
||||
|
||||
@@ -115,6 +116,7 @@ def download(
|
||||
with_smoldocling=_AvailableModels.SMOLDOCLING in to_download,
|
||||
with_smoldocling_mlx=_AvailableModels.SMOLDOCLING_MLX in to_download,
|
||||
with_granite_vision=_AvailableModels.GRANITE_VISION in to_download,
|
||||
with_rapidocr=_AvailableModels.RAPIDOCR in to_download,
|
||||
with_easyocr=_AvailableModels.EASYOCR in to_download,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,6 +81,13 @@ class OcrOptions(BaseOptions):
|
||||
)
|
||||
|
||||
|
||||
class OcrAutoOptions(OcrOptions):
|
||||
"""Options for pick OCR engine automatically."""
|
||||
|
||||
kind: ClassVar[Literal["auto"]] = "auto"
|
||||
lang: List[str] = []
|
||||
|
||||
|
||||
class RapidOcrOptions(OcrOptions):
|
||||
"""Options for the RapidOCR engine."""
|
||||
|
||||
@@ -255,6 +262,7 @@ class PdfBackend(str, Enum):
|
||||
class OcrEngine(str, Enum):
|
||||
"""Enum of valid OCR engines."""
|
||||
|
||||
AUTO = "auto"
|
||||
EASYOCR = "easyocr"
|
||||
TESSERACT_CLI = "tesseract_cli"
|
||||
TESSERACT = "tesseract"
|
||||
@@ -336,7 +344,7 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
|
||||
# If True, text from backend will be used instead of generated text
|
||||
|
||||
table_structure_options: TableStructureOptions = TableStructureOptions()
|
||||
ocr_options: OcrOptions = EasyOcrOptions()
|
||||
ocr_options: OcrOptions = OcrAutoOptions()
|
||||
layout_options: LayoutOptions = LayoutOptions()
|
||||
|
||||
images_scale: float = 1.0
|
||||
|
||||
132
docling/models/auto_ocr_model.py
Normal file
132
docling/models/auto_ocr_model.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
EasyOcrOptions,
|
||||
OcrAutoOptions,
|
||||
OcrMacOptions,
|
||||
OcrOptions,
|
||||
RapidOcrOptions,
|
||||
)
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.ocr_mac_model import OcrMacModel
|
||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OcrAutoModel(BaseOcrModel):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Path],
|
||||
options: OcrAutoOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
super().__init__(
|
||||
enabled=enabled,
|
||||
artifacts_path=artifacts_path,
|
||||
options=options,
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
self.options: OcrAutoOptions
|
||||
|
||||
self._engine: Optional[BaseOcrModel] = None
|
||||
if self.enabled:
|
||||
if "darwin" == sys.platform:
|
||||
try:
|
||||
from ocrmac import ocrmac
|
||||
|
||||
self._engine = OcrMacModel(
|
||||
enabled=self.enabled,
|
||||
artifacts_path=artifacts_path,
|
||||
options=OcrMacOptions(
|
||||
bitmap_area_threshold=self.options.bitmap_area_threshold,
|
||||
force_full_page_ocr=self.options.force_full_page_ocr,
|
||||
),
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
_log.info("Auto OCR model selected ocrmac.")
|
||||
except ImportError:
|
||||
_log.info("ocrmac cannot be used because ocrmac is not installed.")
|
||||
|
||||
if self._engine is None:
|
||||
try:
|
||||
import onnxruntime
|
||||
from rapidocr import EngineType, RapidOCR # type: ignore
|
||||
|
||||
self._engine = RapidOcrModel(
|
||||
enabled=self.enabled,
|
||||
artifacts_path=artifacts_path,
|
||||
options=RapidOcrOptions(
|
||||
backend="onnxruntime",
|
||||
bitmap_area_threshold=self.options.bitmap_area_threshold,
|
||||
force_full_page_ocr=self.options.force_full_page_ocr,
|
||||
),
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
_log.info("Auto OCR model selected rapidocr with onnxruntime.")
|
||||
except ImportError:
|
||||
_log.info(
|
||||
"rapidocr cannot be used because onnxruntime is not installed."
|
||||
)
|
||||
|
||||
if self._engine is None:
|
||||
try:
|
||||
import easyocr
|
||||
|
||||
self._engine = EasyOcrModel(
|
||||
enabled=self.enabled,
|
||||
artifacts_path=artifacts_path,
|
||||
options=EasyOcrOptions(
|
||||
bitmap_area_threshold=self.options.bitmap_area_threshold,
|
||||
force_full_page_ocr=self.options.force_full_page_ocr,
|
||||
),
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
_log.info("Auto OCR model selected easyocr.")
|
||||
except ImportError:
|
||||
_log.info("easyocr cannot be used because it is not installed.")
|
||||
|
||||
if self._engine is None:
|
||||
try:
|
||||
import torch
|
||||
from rapidocr import EngineType, RapidOCR # type: ignore
|
||||
|
||||
self._engine = RapidOcrModel(
|
||||
enabled=self.enabled,
|
||||
artifacts_path=artifacts_path,
|
||||
options=RapidOcrOptions(
|
||||
backend="torch",
|
||||
bitmap_area_threshold=self.options.bitmap_area_threshold,
|
||||
force_full_page_ocr=self.options.force_full_page_ocr,
|
||||
),
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
_log.info("Auto OCR model selected rapidocr with torch.")
|
||||
except ImportError:
|
||||
_log.info(
|
||||
"rapidocr cannot be used because rapidocr or torch is not installed."
|
||||
)
|
||||
|
||||
if self._engine is None:
|
||||
_log.warning("No OCR engine found. Please review the install details.")
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
if not self.enabled or self._engine is None:
|
||||
yield from page_batch
|
||||
return
|
||||
yield from self._engine(conv_res, page_batch)
|
||||
|
||||
@classmethod
|
||||
def get_options_type(cls) -> Type[OcrOptions]:
|
||||
return OcrAutoOptions
|
||||
@@ -1,4 +1,5 @@
|
||||
def ocr_engines():
|
||||
from docling.models.auto_ocr_model import OcrAutoModel
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.ocr_mac_model import OcrMacModel
|
||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||
@@ -7,6 +8,7 @@ def ocr_engines():
|
||||
|
||||
return {
|
||||
"ocr_engines": [
|
||||
OcrAutoModel,
|
||||
EasyOcrModel,
|
||||
OcrMacModel,
|
||||
RapidOcrModel,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
from typing import Literal, Optional, Type, TypedDict
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
@@ -18,11 +18,67 @@ from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
from docling.utils.utils import download_url_with_progress
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
_ModelPathEngines = Literal["onnxruntime", "torch"]
|
||||
_ModelPathTypes = Literal[
|
||||
"det_model_path", "cls_model_path", "rec_model_path", "rec_keys_path"
|
||||
]
|
||||
|
||||
|
||||
class _ModelPathDetail(TypedDict):
|
||||
url: str
|
||||
path: str
|
||||
|
||||
|
||||
class RapidOcrModel(BaseOcrModel):
|
||||
_model_repo_folder = "RapidOcr"
|
||||
# from https://github.com/RapidAI/RapidOCR/blob/main/python/rapidocr/default_models.yaml
|
||||
# matching the default config in https://github.com/RapidAI/RapidOCR/blob/main/python/rapidocr/config.yaml
|
||||
# and naming f"{file_info.engine_type.value}.{file_info.ocr_version.value}.{file_info.task_type.value}"
|
||||
_default_models: dict[
|
||||
_ModelPathEngines, dict[_ModelPathTypes, _ModelPathDetail]
|
||||
] = {
|
||||
"onnxruntime": {
|
||||
"det_model_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/onnx/PP-OCRv4/det/ch_PP-OCRv4_det_infer.onnx",
|
||||
"path": "onnx/PP-OCRv4/det/ch_PP-OCRv4_det_infer.onnx",
|
||||
},
|
||||
"cls_model_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/onnx/PP-OCRv4/cls/ch_ppocr_mobile_v2.0_cls_infer.onnx",
|
||||
"path": "onnx/PP-OCRv4/cls/ch_ppocr_mobile_v2.0_cls_infer.onnx",
|
||||
},
|
||||
"rec_model_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/onnx/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.onnx",
|
||||
"path": "onnx/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.onnx",
|
||||
},
|
||||
"rec_keys_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v2.0.7/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt",
|
||||
"path": "paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt",
|
||||
},
|
||||
},
|
||||
"torch": {
|
||||
"det_model_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/torch/PP-OCRv4/det/ch_PP-OCRv4_det_infer.pth",
|
||||
"path": "torch/PP-OCRv4/det/ch_PP-OCRv4_det_infer.pth",
|
||||
},
|
||||
"cls_model_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/torch/PP-OCRv4/cls/ch_ptocr_mobile_v2.0_cls_infer.pth",
|
||||
"path": "torch/PP-OCRv4/cls/ch_ptocr_mobile_v2.0_cls_infer.pth",
|
||||
},
|
||||
"rec_model_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/torch/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.pth",
|
||||
"path": "torch/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.pth",
|
||||
},
|
||||
"rec_keys_path": {
|
||||
"url": "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v3.4.0/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt",
|
||||
"path": "paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@@ -62,25 +118,66 @@ class RapidOcrModel(BaseOcrModel):
|
||||
}
|
||||
backend_enum = _ALIASES.get(self.options.backend, EngineType.ONNXRUNTIME)
|
||||
|
||||
det_model_path = self.options.det_model_path
|
||||
cls_model_path = self.options.cls_model_path
|
||||
rec_model_path = self.options.rec_model_path
|
||||
rec_keys_path = self.options.rec_keys_path
|
||||
if artifacts_path is not None:
|
||||
det_model_path = (
|
||||
det_model_path
|
||||
or artifacts_path
|
||||
/ self._model_repo_folder
|
||||
/ self._default_models[backend_enum.value]["det_model_path"]["path"]
|
||||
)
|
||||
cls_model_path = (
|
||||
cls_model_path
|
||||
or artifacts_path
|
||||
/ self._model_repo_folder
|
||||
/ self._default_models[backend_enum.value]["cls_model_path"]["path"]
|
||||
)
|
||||
rec_model_path = (
|
||||
rec_model_path
|
||||
or artifacts_path
|
||||
/ self._model_repo_folder
|
||||
/ self._default_models[backend_enum.value]["rec_model_path"]["path"]
|
||||
)
|
||||
rec_keys_path = (
|
||||
rec_keys_path
|
||||
or artifacts_path
|
||||
/ self._model_repo_folder
|
||||
/ self._default_models[backend_enum.value]["rec_keys_path"]["path"]
|
||||
)
|
||||
|
||||
for model_path in (
|
||||
rec_keys_path,
|
||||
cls_model_path,
|
||||
rec_model_path,
|
||||
rec_keys_path,
|
||||
):
|
||||
if model_path is None:
|
||||
continue
|
||||
if not Path(model_path).exists():
|
||||
_log.warning(f"The provided model path {model_path} is not found.")
|
||||
|
||||
params = {
|
||||
# Global settings (these are still correct)
|
||||
"Global.text_score": self.options.text_score,
|
||||
"Global.font_path": self.options.font_path,
|
||||
# "Global.verbose": self.options.print_verbose,
|
||||
# Detection model settings
|
||||
"Det.model_path": self.options.det_model_path,
|
||||
"Det.model_path": det_model_path,
|
||||
"Det.use_cuda": use_cuda,
|
||||
"Det.use_dml": use_dml,
|
||||
"Det.intra_op_num_threads": intra_op_num_threads,
|
||||
# Classification model settings
|
||||
"Cls.model_path": self.options.cls_model_path,
|
||||
"Cls.model_path": cls_model_path,
|
||||
"Cls.use_cuda": use_cuda,
|
||||
"Cls.use_dml": use_dml,
|
||||
"Cls.intra_op_num_threads": intra_op_num_threads,
|
||||
# Recognition model settings
|
||||
"Rec.model_path": self.options.rec_model_path,
|
||||
"Rec.model_path": rec_model_path,
|
||||
"Rec.font_path": self.options.rec_font_path,
|
||||
"Rec.keys_path": self.options.rec_keys_path,
|
||||
"Rec.keys_path": rec_keys_path,
|
||||
"Rec.use_cuda": use_cuda,
|
||||
"Rec.use_dml": use_dml,
|
||||
"Rec.intra_op_num_threads": intra_op_num_threads,
|
||||
@@ -102,6 +199,30 @@ class RapidOcrModel(BaseOcrModel):
|
||||
params=params,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
backend: _ModelPathEngines,
|
||||
local_dir: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
if local_dir is None:
|
||||
local_dir = settings.cache_dir / "models" / RapidOcrModel._model_repo_folder
|
||||
|
||||
local_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download models
|
||||
for model_type, model_details in RapidOcrModel._default_models[backend].items():
|
||||
output_path = local_dir / model_details["path"]
|
||||
if output_path.exists() and not force:
|
||||
continue
|
||||
output_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
buf = download_url_with_progress(model_details["url"], progress=progress)
|
||||
with output_path.open("wb") as fw:
|
||||
fw.write(buf.read())
|
||||
|
||||
return local_dir
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
@@ -20,6 +20,7 @@ from docling.models.document_picture_classifier import DocumentPictureClassifier
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
|
||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.models.utils.hf_model_download import download_hf_model
|
||||
|
||||
@@ -41,6 +42,7 @@ def download_models(
|
||||
with_smoldocling: bool = False,
|
||||
with_smoldocling_mlx: bool = False,
|
||||
with_granite_vision: bool = False,
|
||||
with_rapidocr: bool = True,
|
||||
with_easyocr: bool = True,
|
||||
):
|
||||
if output_dir is None:
|
||||
@@ -135,6 +137,16 @@ def download_models(
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
if with_rapidocr:
|
||||
for backend in ("torch", "onnxruntime"):
|
||||
_log.info(f"Downloading rapidocr {backend} models...")
|
||||
RapidOcrModel.download_models(
|
||||
backend=backend,
|
||||
local_dir=output_dir / RapidOcrModel._model_repo_folder,
|
||||
force=force,
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
if with_easyocr:
|
||||
_log.info("Downloading easyocr models...")
|
||||
EasyOcrModel.download_models(
|
||||
|
||||
Reference in New Issue
Block a user