From 898a497e714066ce1e1275a1548218f6f158c1f2 Mon Sep 17 00:00:00 2001 From: vdaleke Date: Fri, 7 Feb 2025 14:36:28 +0300 Subject: [PATCH] feat: add support for user-provided OCR model The ocr_model field added to the OcrOptions class with a reference to the BaseOcrModel inheritor class. In case the options are not one of supported model options, the class from this field is used. Signed-off-by: vdaleke --- docling/datamodel/pipeline_options.py | 16 ++++++++-------- docling/pipeline/standard_pdf_pipeline.py | 8 ++++++++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 3b6401b6..c181ac7f 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -2,11 +2,13 @@ import logging import os from enum import Enum from pathlib import Path -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Type, Union from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict +from docling.models.base_ocr_model import BaseOcrModel + _log = logging.getLogger(__name__) @@ -85,6 +87,7 @@ class OcrOptions(BaseModel): bitmap_area_threshold: float = ( 0.05 # percentage of the area for a bitmap to processed with OCR ) + ocr_model: Optional[Type[BaseOcrModel]] = None class RapidOcrOptions(OcrOptions): @@ -151,6 +154,7 @@ class TesseractCliOcrOptions(OcrOptions): kind: Literal["tesseract"] = "tesseract" lang: List[str] = ["fra", "deu", "spa", "eng"] + tesseract_cmd: str = "tesseract" path: Optional[str] = None @@ -164,6 +168,7 @@ class TesseractOcrOptions(OcrOptions): kind: Literal["tesserocr"] = "tesserocr" lang: List[str] = ["fra", "deu", "spa", "eng"] + path: Optional[str] = None model_config = ConfigDict( @@ -176,6 +181,7 @@ class OcrMacOptions(OcrOptions): kind: Literal["ocrmac"] = "ocrmac" lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"] + recognition: str = "accurate" framework: str = "vision" @@ -271,13 +277,7 @@ class PdfPipelineOptions(PipelineOptions): do_picture_description: bool = False # True: run describe pictures in documents table_structure_options: TableStructureOptions = TableStructureOptions() - ocr_options: Union[ - EasyOcrOptions, - TesseractCliOcrOptions, - TesseractOcrOptions, - OcrMacOptions, - RapidOcrOptions, - ] = Field(EasyOcrOptions(), discriminator="kind") + ocr_options: OcrOptions = EasyOcrOptions() picture_description_options: Annotated[ Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions], Field(discriminator="kind"), diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 13e435f9..47110dcd 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -13,6 +13,7 @@ from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import ( EasyOcrOptions, OcrMacOptions, + OcrOptions, PdfPipelineOptions, PictureDescriptionApiOptions, PictureDescriptionVlmOptions, @@ -73,6 +74,7 @@ class StandardPdfPipeline(PaginatedPipeline): if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None: raise RuntimeError( f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}." + " You can provide a custom OCR model class in the options." ) self.build_pipe = [ @@ -190,6 +192,12 @@ class StandardPdfPipeline(PaginatedPipeline): enabled=self.pipeline_options.do_ocr, options=self.pipeline_options.ocr_options, ) + elif isinstance(self.pipeline_options.ocr_options, OcrOptions): + if self.pipeline_options.ocr_options.ocr_model is not None: + return self.pipeline_options.ocr_options.ocr_model( + enabled=self.pipeline_options.do_ocr, + options=self.pipeline_options.ocr_options, + ) return None def get_picture_description_model(