Pydantic field validator and comment restored.

Signed-off-by: ahn <ahn@zurich.ibm.com>
This commit is contained in:
ahn 2025-01-31 16:21:36 +01:00
parent b9668877be
commit dd0728e646
3 changed files with 14 additions and 41 deletions

View File

@ -41,7 +41,7 @@ class AcceleratorOptions(BaseSettings):
num_threads: int = 4
device: str = "auto"
@validator("device")
@field_validator("device")
def validate_device(cls, value):
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
@ -55,6 +55,15 @@ class AcceleratorOptions(BaseSettings):
@model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
the same functionality. In case the alias envvar is set and the user tries to override the
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the evvar value for that parameter.
"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
if input_num_threads is None:

View File

@ -18,7 +18,6 @@ from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
from docling.utils.utils import download_url_with_progress
_log = logging.getLogger(__name__)
@ -56,6 +55,7 @@ class EasyOcrModel(BaseOcrModel):
for x in [
AcceleratorDevice.CUDA.value,
AcceleratorDevice.MPS.value,
"cuda:",
]
]
)
@ -82,40 +82,6 @@ class EasyOcrModel(BaseOcrModel):
verbose=False,
)
@staticmethod
def download_models(
detection_models: List[str] = ["craft"],
recognition_models: List[str] = ["english_g2", "latin_g2"],
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
from easyocr.config import detection_models as det_models_dict
from easyocr.config import recognition_models as rec_models_dict
if local_dir is None:
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
local_dir.mkdir(parents=True, exist_ok=True)
# Collect models to download
download_list = []
for model_name in detection_models:
if model_name in det_models_dict:
download_list.append(det_models_dict[model_name])
for model_name in recognition_models:
if model_name in rec_models_dict["gen2"]:
download_list.append(rec_models_dict["gen2"][model_name])
# Download models
for model_details in download_list:
buf = download_url_with_progress(model_details["url"], progress=progress)
with zipfile.ZipFile(buf, "r") as zip_ref:
zip_ref.extractall(local_dir)
return local_dir
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:

View File

@ -31,9 +31,7 @@ def main():
# )
# easyocr doesnt support cuda:N allocation
# accelerator_options = AcceleratorOptions(
# num_threads=8, device="cuda:1"
# )
# accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:0")
pipeline_options = PdfPipelineOptions()
pipeline_options.accelerator_options = accelerator_options
@ -59,8 +57,8 @@ def main():
# List with total time per document
doc_conversion_secs = conversion_result.timings["pipeline_total"].times
md = doc.export_to_markdown()
print(md)
# md = doc.export_to_markdown()
# print(md)
print(f"Conversion secs: {doc_conversion_secs}")