mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
fix: Improve the pydantic objects in the pipeline_options and imports.
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
71f3a7ac3c
commit
f63e5ef3b5
@ -2,9 +2,9 @@ import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
@ -28,11 +28,12 @@ class AcceleratorOptions(BaseSettings):
|
||||
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
|
||||
)
|
||||
|
||||
# num_threads: int = Field(default=4, validation_alias="omp_num_threads")
|
||||
num_threads: int = 4
|
||||
device: AcceleratorDevice = AcceleratorDevice.AUTO
|
||||
|
||||
def __init__(self, **data):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_alternative_envvars(cls, data: Any) -> Any:
|
||||
r"""
|
||||
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
|
||||
The alternative envvar is used only if it is valid and the regular envvar is not set.
|
||||
@ -42,21 +43,22 @@ class AcceleratorOptions(BaseSettings):
|
||||
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
|
||||
as an extra input instead of simply overwriting the evvar value for that parameter.
|
||||
"""
|
||||
input_num_threads = data.get("num_threads")
|
||||
if isinstance(data, dict):
|
||||
input_num_threads = data.get("num_threads")
|
||||
|
||||
# Check if to set the num_threads from the alternative envvar
|
||||
if input_num_threads is None:
|
||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
||||
if docling_num_threads is None and omp_num_threads is not None:
|
||||
try:
|
||||
data["num_threads"] = int(omp_num_threads)
|
||||
except ValueError:
|
||||
_log.error(
|
||||
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
|
||||
omp_num_threads,
|
||||
)
|
||||
super().__init__(**data)
|
||||
# Check if to set the num_threads from the alternative envvar
|
||||
if input_num_threads is None:
|
||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
||||
if docling_num_threads is None and omp_num_threads is not None:
|
||||
try:
|
||||
data["num_threads"] = int(omp_num_threads)
|
||||
except ValueError:
|
||||
_log.error(
|
||||
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
|
||||
omp_num_threads,
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class TableFormerMode(str, Enum):
|
||||
@ -132,18 +134,17 @@ class EasyOcrOptions(OcrOptions):
|
||||
|
||||
kind: Literal["easyocr"] = "easyocr"
|
||||
lang: List[str] = ["fr", "de", "es", "en"]
|
||||
use_gpu: bool = Field(
|
||||
default=True,
|
||||
deprecated=(
|
||||
"Field `use_gpu` is deprecated. "
|
||||
"When `use_gpu == True and accelerator_options.device == AcceleratorDevice.CUDA` "
|
||||
"then the GPU is used to run EasyOCR. "
|
||||
"When `use_gpu == False`, EasyOCR runs in CPU"
|
||||
use_gpu: Annotated[
|
||||
int,
|
||||
Field(
|
||||
deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
|
||||
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
|
||||
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
|
||||
),
|
||||
)
|
||||
] = True
|
||||
|
||||
model_storage_directory: Optional[str] = None
|
||||
download_enabled: bool = True # same default as easyocr.Reader
|
||||
download_enabled: bool = True
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
|
@ -14,7 +14,7 @@ from docling.datamodel.pipeline_options import (
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@ -43,7 +43,7 @@ class EasyOcrModel(BaseOcrModel):
|
||||
|
||||
use_gpu = False
|
||||
if self.options.use_gpu:
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
device = decide_device(accelerator_options.device)
|
||||
# Enable easyocr GPU if running on CUDA, MPS
|
||||
use_gpu = device in [
|
||||
AcceleratorDevice.CUDA,
|
||||
|
@ -20,7 +20,7 @@ from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
@ -50,7 +50,7 @@ class LayoutModel(BasePageModel):
|
||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||
|
||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
device = decide_device(accelerator_options.device)
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
artifacts_path, device, accelerator_options.num_threads
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ from docling.datamodel.pipeline_options import (
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@ -41,7 +41,7 @@ class RapidOcrModel(BaseOcrModel):
|
||||
)
|
||||
|
||||
# Decide the accelerator devices
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
device = decide_device(accelerator_options.device)
|
||||
use_cuda = device == AcceleratorDevice.CUDA
|
||||
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
|
||||
intra_op_num_threads = accelerator_options.num_threads
|
||||
|
@ -16,7 +16,7 @@ from docling.datamodel.pipeline_options import (
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ class TableStructureModel(BasePageModel):
|
||||
# Third Party
|
||||
import docling_ibm_models.tableformer.common as c
|
||||
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
||||
self.tm_config["model"]["save_dir"] = artifacts_path
|
||||
|
@ -38,9 +38,6 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StandardPdfPipeline(PaginatedPipeline):
|
||||
# TODO: Revise after having the models in HF
|
||||
# _layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
|
||||
|
||||
_layout_model_path = "model_artifacts/layout"
|
||||
_table_model_path = "model_artifacts/tableformer"
|
||||
|
||||
@ -103,13 +100,6 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# TODO: Revise after having the models in HF
|
||||
# download_path = snapshot_download(
|
||||
# repo_id="ds4sd/docling-models",
|
||||
# force_download=force,
|
||||
# local_dir=local_dir,
|
||||
# revision="v2.0.1",
|
||||
# )
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/docling-models",
|
||||
force_download=force,
|
||||
|
Loading…
Reference in New Issue
Block a user