fix: Improve the pydantic objects in the pipeline_options and imports.

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Nikos Livathinos 2024-12-06 14:56:35 +01:00
parent 71f3a7ac3c
commit f63e5ef3b5
6 changed files with 36 additions and 45 deletions

View File

@ -2,9 +2,9 @@ import logging
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
@ -28,11 +28,12 @@ class AcceleratorOptions(BaseSettings):
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
)
# num_threads: int = Field(default=4, validation_alias="omp_num_threads")
num_threads: int = 4
device: AcceleratorDevice = AcceleratorDevice.AUTO
def __init__(self, **data):
@model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
@ -42,21 +43,22 @@ class AcceleratorOptions(BaseSettings):
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the envvar value for that parameter.
"""
input_num_threads = data.get("num_threads")
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
# Check whether to set num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
omp_num_threads = os.getenv("OMP_NUM_THREADS")
if docling_num_threads is None and omp_num_threads is not None:
try:
data["num_threads"] = int(omp_num_threads)
except ValueError:
_log.error(
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
omp_num_threads,
)
super().__init__(**data)
# Check whether to set num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
omp_num_threads = os.getenv("OMP_NUM_THREADS")
if docling_num_threads is None and omp_num_threads is not None:
try:
data["num_threads"] = int(omp_num_threads)
except ValueError:
_log.error(
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
omp_num_threads,
)
return data
class TableFormerMode(str, Enum):
@ -132,18 +134,17 @@ class EasyOcrOptions(OcrOptions):
kind: Literal["easyocr"] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"]
use_gpu: bool = Field(
default=True,
deprecated=(
"Field `use_gpu` is deprecated. "
"When `use_gpu == True and accelerator_options.device == AcceleratorDevice.CUDA` "
"then the GPU is used to run EasyOCR. "
"When `use_gpu == False`, EasyOCR runs in CPU"
use_gpu: Annotated[
int,
Field(
deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
),
)
] = True
model_storage_directory: Optional[str] = None
download_enabled: bool = True # same default as easyocr.Reader
download_enabled: bool = True
model_config = ConfigDict(
extra="forbid",

View File

@ -14,7 +14,7 @@ from docling.datamodel.pipeline_options import (
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils import accelerator_utils as au
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
@ -43,7 +43,7 @@ class EasyOcrModel(BaseOcrModel):
use_gpu = False
if self.options.use_gpu:
device = au.decide_device(accelerator_options.device)
device = decide_device(accelerator_options.device)
# Enable easyocr GPU if running on CUDA, MPS
use_gpu = device in [
AcceleratorDevice.CUDA,

View File

@ -20,7 +20,7 @@ from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils import accelerator_utils as au
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder
@ -50,7 +50,7 @@ class LayoutModel(BasePageModel):
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
device = au.decide_device(accelerator_options.device)
device = decide_device(accelerator_options.device)
self.layout_predictor = LayoutPredictor(
artifacts_path, device, accelerator_options.num_threads
)

View File

@ -13,7 +13,7 @@ from docling.datamodel.pipeline_options import (
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils import accelerator_utils as au
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
@ -41,7 +41,7 @@ class RapidOcrModel(BaseOcrModel):
)
# Decide the accelerator devices
device = au.decide_device(accelerator_options.device)
device = decide_device(accelerator_options.device)
use_cuda = device == AcceleratorDevice.CUDA
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
intra_op_num_threads = accelerator_options.num_threads

View File

@ -16,7 +16,7 @@ from docling.datamodel.pipeline_options import (
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils import accelerator_utils as au
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
@ -40,7 +40,7 @@ class TableStructureModel(BasePageModel):
# Third Party
import docling_ibm_models.tableformer.common as c
device = au.decide_device(accelerator_options.device)
device = decide_device(accelerator_options.device)
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
self.tm_config["model"]["save_dir"] = artifacts_path

View File

@ -38,9 +38,6 @@ _log = logging.getLogger(__name__)
class StandardPdfPipeline(PaginatedPipeline):
# TODO: Revise after having the models in HF
# _layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
_layout_model_path = "model_artifacts/layout"
_table_model_path = "model_artifacts/tableformer"
@ -103,13 +100,6 @@ class StandardPdfPipeline(PaginatedPipeline):
) -> Path:
from huggingface_hub import snapshot_download
# TODO: Revise after having the models in HF
# download_path = snapshot_download(
# repo_id="ds4sd/docling-models",
# force_download=force,
# local_dir=local_dir,
# revision="v2.0.1",
# )
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,