diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 354b997d..aec6a5ac 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -41,7 +41,7 @@ class AcceleratorOptions(BaseSettings): num_threads: int = 4 device: str = "auto" - @validator("device") + @field_validator("device") def validate_device(cls, value): # "auto", "cpu", "cuda", "mps", or "cuda:N" if value in {d.value for d in AcceleratorDevice} or re.match( @@ -55,6 +55,15 @@ class AcceleratorOptions(BaseSettings): @model_validator(mode="before") @classmethod def check_alternative_envvars(cls, data: Any) -> Any: + r""" + Set num_threads from the "alternative" envvar OMP_NUM_THREADS. + The alternative envvar is used only if it is valid and the regular envvar is not set. + + Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide + the same functionality. In case the alias envvar is set and the user tries to override the + parameter in settings initialization, Pydantic treats the parameter provided in __init__() + as an extra input instead of simply overwriting the evvar value for that parameter. + """ if isinstance(data, dict): input_num_threads = data.get("num_threads") if input_num_threads is None: diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index 0eccb988..9aec872a 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -18,7 +18,6 @@ from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.accelerator_utils import decide_device from docling.utils.profiling import TimeRecorder -from docling.utils.utils import download_url_with_progress _log = logging.getLogger(__name__) @@ -56,6 +55,7 @@ class EasyOcrModel(BaseOcrModel): for x in [ AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value, + "cuda:", ] ] ) @@ -82,40 +82,6 @@ class EasyOcrModel(BaseOcrModel): verbose=False, ) - @staticmethod - def download_models( - detection_models: List[str] = ["craft"], - recognition_models: List[str] = ["english_g2", "latin_g2"], - local_dir: Optional[Path] = None, - force: bool = False, - progress: bool = False, - ) -> Path: - # Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py - from easyocr.config import detection_models as det_models_dict - from easyocr.config import recognition_models as rec_models_dict - - if local_dir is None: - local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder - - local_dir.mkdir(parents=True, exist_ok=True) - - # Collect models to download - download_list = [] - for model_name in detection_models: - if model_name in det_models_dict: - download_list.append(det_models_dict[model_name]) - for model_name in recognition_models: - if model_name in rec_models_dict["gen2"]: - download_list.append(rec_models_dict["gen2"][model_name]) - - # Download models - for model_details in download_list: - buf = download_url_with_progress(model_details["url"], progress=progress) - with zipfile.ZipFile(buf, "r") as zip_ref: - zip_ref.extractall(local_dir) - - return local_dir - def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: diff --git a/docs/examples/run_with_accelerator.py b/docs/examples/run_with_accelerator.py index fff4accc..aca19a03 100644 --- a/docs/examples/run_with_accelerator.py +++ b/docs/examples/run_with_accelerator.py @@ -31,9 +31,7 @@ def main(): # ) # easyocr doesnt support cuda:N allocation - # accelerator_options = AcceleratorOptions( - # num_threads=8, device="cuda:1" - # ) + # accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:0") pipeline_options = PdfPipelineOptions() pipeline_options.accelerator_options = accelerator_options @@ -59,8 +57,8 @@ def main(): # List with total time per document doc_conversion_secs = conversion_result.timings["pipeline_total"].times - md = doc.export_to_markdown() - print(md) + # md = doc.export_to_markdown() + # print(md) print(f"Conversion secs: {doc_conversion_secs}")