mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-29 05:24:28 +00:00
fix: Improve the pydantic objects in the pipeline_options and imports.
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
71f3a7ac3c
commit
f63e5ef3b5
@ -2,9 +2,9 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
PydanticBaseSettingsSource,
|
PydanticBaseSettingsSource,
|
||||||
@ -28,11 +28,12 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
|
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# num_threads: int = Field(default=4, validation_alias="omp_num_threads")
|
|
||||||
num_threads: int = 4
|
num_threads: int = 4
|
||||||
device: AcceleratorDevice = AcceleratorDevice.AUTO
|
device: AcceleratorDevice = AcceleratorDevice.AUTO
|
||||||
|
|
||||||
def __init__(self, **data):
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_alternative_envvars(cls, data: Any) -> Any:
|
||||||
r"""
|
r"""
|
||||||
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
|
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
|
||||||
The alternative envvar is used only if it is valid and the regular envvar is not set.
|
The alternative envvar is used only if it is valid and the regular envvar is not set.
|
||||||
@ -42,21 +43,22 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
|
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
|
||||||
as an extra input instead of simply overwriting the envvar value for that parameter.
|
as an extra input instead of simply overwriting the envvar value for that parameter.
|
||||||
"""
|
"""
|
||||||
input_num_threads = data.get("num_threads")
|
if isinstance(data, dict):
|
||||||
|
input_num_threads = data.get("num_threads")
|
||||||
|
|
||||||
# Check whether to set num_threads from the alternative envvar
|
# Check whether to set num_threads from the alternative envvar
|
||||||
if input_num_threads is None:
|
if input_num_threads is None:
|
||||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||||
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
||||||
if docling_num_threads is None and omp_num_threads is not None:
|
if docling_num_threads is None and omp_num_threads is not None:
|
||||||
try:
|
try:
|
||||||
data["num_threads"] = int(omp_num_threads)
|
data["num_threads"] = int(omp_num_threads)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
_log.error(
|
_log.error(
|
||||||
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
|
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
|
||||||
omp_num_threads,
|
omp_num_threads,
|
||||||
)
|
)
|
||||||
super().__init__(**data)
|
return data
|
||||||
|
|
||||||
|
|
||||||
class TableFormerMode(str, Enum):
|
class TableFormerMode(str, Enum):
|
||||||
@ -132,18 +134,17 @@ class EasyOcrOptions(OcrOptions):
|
|||||||
|
|
||||||
kind: Literal["easyocr"] = "easyocr"
|
kind: Literal["easyocr"] = "easyocr"
|
||||||
lang: List[str] = ["fr", "de", "es", "en"]
|
lang: List[str] = ["fr", "de", "es", "en"]
|
||||||
use_gpu: bool = Field(
|
use_gpu: Annotated[
|
||||||
default=True,
|
int,
|
||||||
deprecated=(
|
Field(
|
||||||
"Field `use_gpu` is deprecated. "
|
deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
|
||||||
"When `use_gpu == True and accelerator_options.device == AcceleratorDevice.CUDA` "
|
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
|
||||||
"then the GPU is used to run EasyOCR. "
|
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
|
||||||
"When `use_gpu == False`, EasyOCR runs in CPU"
|
|
||||||
),
|
),
|
||||||
)
|
] = True
|
||||||
|
|
||||||
model_storage_directory: Optional[str] = None
|
model_storage_directory: Optional[str] = None
|
||||||
download_enabled: bool = True # same default as easyocr.Reader
|
download_enabled: bool = True
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
extra="forbid",
|
extra="forbid",
|
||||||
|
@ -14,7 +14,7 @@ from docling.datamodel.pipeline_options import (
|
|||||||
)
|
)
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_ocr_model import BaseOcrModel
|
from docling.models.base_ocr_model import BaseOcrModel
|
||||||
from docling.utils import accelerator_utils as au
|
from docling.utils.accelerator_utils import decide_device
|
||||||
from docling.utils.profiling import TimeRecorder
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
@ -43,7 +43,7 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
|
|
||||||
use_gpu = False
|
use_gpu = False
|
||||||
if self.options.use_gpu:
|
if self.options.use_gpu:
|
||||||
device = au.decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
# Enable easyocr GPU if running on CUDA, MPS
|
# Enable easyocr GPU if running on CUDA, MPS
|
||||||
use_gpu = device in [
|
use_gpu = device in [
|
||||||
AcceleratorDevice.CUDA,
|
AcceleratorDevice.CUDA,
|
||||||
|
@ -20,7 +20,7 @@ from docling.datamodel.document import ConversionResult
|
|||||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel
|
||||||
from docling.utils import accelerator_utils as au
|
from docling.utils.accelerator_utils import decide_device
|
||||||
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
||||||
from docling.utils.profiling import TimeRecorder
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ class LayoutModel(BasePageModel):
|
|||||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||||
|
|
||||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||||
device = au.decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
self.layout_predictor = LayoutPredictor(
|
self.layout_predictor = LayoutPredictor(
|
||||||
artifacts_path, device, accelerator_options.num_threads
|
artifacts_path, device, accelerator_options.num_threads
|
||||||
)
|
)
|
||||||
|
@ -13,7 +13,7 @@ from docling.datamodel.pipeline_options import (
|
|||||||
)
|
)
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_ocr_model import BaseOcrModel
|
from docling.models.base_ocr_model import BaseOcrModel
|
||||||
from docling.utils import accelerator_utils as au
|
from docling.utils.accelerator_utils import decide_device
|
||||||
from docling.utils.profiling import TimeRecorder
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
@ -41,7 +41,7 @@ class RapidOcrModel(BaseOcrModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Decide the accelerator devices
|
# Decide the accelerator devices
|
||||||
device = au.decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
use_cuda = device == AcceleratorDevice.CUDA
|
use_cuda = device == AcceleratorDevice.CUDA
|
||||||
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
|
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
|
||||||
intra_op_num_threads = accelerator_options.num_threads
|
intra_op_num_threads = accelerator_options.num_threads
|
||||||
|
@ -16,7 +16,7 @@ from docling.datamodel.pipeline_options import (
|
|||||||
)
|
)
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel
|
||||||
from docling.utils import accelerator_utils as au
|
from docling.utils.accelerator_utils import decide_device
|
||||||
from docling.utils.profiling import TimeRecorder
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ class TableStructureModel(BasePageModel):
|
|||||||
# Third Party
|
# Third Party
|
||||||
import docling_ibm_models.tableformer.common as c
|
import docling_ibm_models.tableformer.common as c
|
||||||
|
|
||||||
device = au.decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
|
|
||||||
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
||||||
self.tm_config["model"]["save_dir"] = artifacts_path
|
self.tm_config["model"]["save_dir"] = artifacts_path
|
||||||
|
@ -38,9 +38,6 @@ _log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class StandardPdfPipeline(PaginatedPipeline):
|
class StandardPdfPipeline(PaginatedPipeline):
|
||||||
# TODO: Revise after having the models in HF
|
|
||||||
# _layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
|
|
||||||
|
|
||||||
_layout_model_path = "model_artifacts/layout"
|
_layout_model_path = "model_artifacts/layout"
|
||||||
_table_model_path = "model_artifacts/tableformer"
|
_table_model_path = "model_artifacts/tableformer"
|
||||||
|
|
||||||
@ -103,13 +100,6 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
) -> Path:
|
) -> Path:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
# TODO: Revise after having the models in HF
|
|
||||||
# download_path = snapshot_download(
|
|
||||||
# repo_id="ds4sd/docling-models",
|
|
||||||
# force_download=force,
|
|
||||||
# local_dir=local_dir,
|
|
||||||
# revision="v2.0.1",
|
|
||||||
# )
|
|
||||||
download_path = snapshot_download(
|
download_path = snapshot_download(
|
||||||
repo_id="ds4sd/docling-models",
|
repo_id="ds4sd/docling-models",
|
||||||
force_download=force,
|
force_download=force,
|
||||||
|
Loading…
Reference in New Issue
Block a user