fix: Improve the pydantic objects in the pipeline_options and imports.

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Nikos Livathinos 2024-12-06 14:56:35 +01:00
parent 71f3a7ac3c
commit f63e5ef3b5
6 changed files with 36 additions and 45 deletions

View File

@ -2,9 +2,9 @@ import logging
import os import os
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
PydanticBaseSettingsSource, PydanticBaseSettingsSource,
@ -28,11 +28,12 @@ class AcceleratorOptions(BaseSettings):
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
) )
# num_threads: int = Field(default=4, validation_alias="omp_num_threads")
num_threads: int = 4 num_threads: int = 4
device: AcceleratorDevice = AcceleratorDevice.AUTO device: AcceleratorDevice = AcceleratorDevice.AUTO
def __init__(self, **data): @model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r""" r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS. Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set. The alternative envvar is used only if it is valid and the regular envvar is not set.
@ -42,6 +43,7 @@ class AcceleratorOptions(BaseSettings):
parameter in settings initialization, Pydantic treats the parameter provided in __init__() parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the evvar value for that parameter. as an extra input instead of simply overwriting the evvar value for that parameter.
""" """
if isinstance(data, dict):
input_num_threads = data.get("num_threads") input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar # Check if to set the num_threads from the alternative envvar
@ -56,7 +58,7 @@ class AcceleratorOptions(BaseSettings):
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'", "Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
omp_num_threads, omp_num_threads,
) )
super().__init__(**data) return data
class TableFormerMode(str, Enum): class TableFormerMode(str, Enum):
@ -132,18 +134,17 @@ class EasyOcrOptions(OcrOptions):
kind: Literal["easyocr"] = "easyocr" kind: Literal["easyocr"] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"] lang: List[str] = ["fr", "de", "es", "en"]
use_gpu: bool = Field( use_gpu: Annotated[
default=True, int,
deprecated=( Field(
"Field `use_gpu` is deprecated. " deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
"When `use_gpu == True and accelerator_options.device == AcceleratorDevice.CUDA` " "When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
"then the GPU is used to run EasyOCR. " "to run EasyOCR. Otherwise, EasyOCR runs in CPU."
"When `use_gpu == False`, EasyOCR runs in CPU"
), ),
) ] = True
model_storage_directory: Optional[str] = None model_storage_directory: Optional[str] = None
download_enabled: bool = True # same default as easyocr.Reader download_enabled: bool = True
model_config = ConfigDict( model_config = ConfigDict(
extra="forbid", extra="forbid",

View File

@ -14,7 +14,7 @@ from docling.datamodel.pipeline_options import (
) )
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.utils import accelerator_utils as au from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -43,7 +43,7 @@ class EasyOcrModel(BaseOcrModel):
use_gpu = False use_gpu = False
if self.options.use_gpu: if self.options.use_gpu:
device = au.decide_device(accelerator_options.device) device = decide_device(accelerator_options.device)
# Enable easyocr GPU if running on CUDA, MPS # Enable easyocr GPU if running on CUDA, MPS
use_gpu = device in [ use_gpu = device in [
AcceleratorDevice.CUDA, AcceleratorDevice.CUDA,

View File

@ -20,7 +20,7 @@ from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel
from docling.utils import accelerator_utils as au from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
@ -50,7 +50,7 @@ class LayoutModel(BasePageModel):
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions): def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
device = au.decide_device(accelerator_options.device) device = decide_device(accelerator_options.device)
self.layout_predictor = LayoutPredictor( self.layout_predictor = LayoutPredictor(
artifacts_path, device, accelerator_options.num_threads artifacts_path, device, accelerator_options.num_threads
) )

View File

@ -13,7 +13,7 @@ from docling.datamodel.pipeline_options import (
) )
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.utils import accelerator_utils as au from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -41,7 +41,7 @@ class RapidOcrModel(BaseOcrModel):
) )
# Decide the accelerator devices # Decide the accelerator devices
device = au.decide_device(accelerator_options.device) device = decide_device(accelerator_options.device)
use_cuda = device == AcceleratorDevice.CUDA use_cuda = device == AcceleratorDevice.CUDA
use_dml = accelerator_options.device == AcceleratorDevice.AUTO use_dml = accelerator_options.device == AcceleratorDevice.AUTO
intra_op_num_threads = accelerator_options.num_threads intra_op_num_threads = accelerator_options.num_threads

View File

@ -16,7 +16,7 @@ from docling.datamodel.pipeline_options import (
) )
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel
from docling.utils import accelerator_utils as au from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
@ -40,7 +40,7 @@ class TableStructureModel(BasePageModel):
# Third Party # Third Party
import docling_ibm_models.tableformer.common as c import docling_ibm_models.tableformer.common as c
device = au.decide_device(accelerator_options.device) device = decide_device(accelerator_options.device)
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json") self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
self.tm_config["model"]["save_dir"] = artifacts_path self.tm_config["model"]["save_dir"] = artifacts_path

View File

@ -38,9 +38,6 @@ _log = logging.getLogger(__name__)
class StandardPdfPipeline(PaginatedPipeline): class StandardPdfPipeline(PaginatedPipeline):
# TODO: Revise after having the models in HF
# _layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
_layout_model_path = "model_artifacts/layout" _layout_model_path = "model_artifacts/layout"
_table_model_path = "model_artifacts/tableformer" _table_model_path = "model_artifacts/tableformer"
@ -103,13 +100,6 @@ class StandardPdfPipeline(PaginatedPipeline):
) -> Path: ) -> Path:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
# TODO: Revise after having the models in HF
# download_path = snapshot_download(
# repo_id="ds4sd/docling-models",
# force_download=force,
# local_dir=local_dir,
# revision="v2.0.1",
# )
download_path = snapshot_download( download_path = snapshot_download(
repo_id="ds4sd/docling-models", repo_id="ds4sd/docling-models",
force_download=force, force_download=force,