From d4eee87b267e9cc58672198583c7311c9f2f7289 Mon Sep 17 00:00:00 2001 From: Michele Dolfi Date: Fri, 7 Feb 2025 13:59:24 +0100 Subject: [PATCH] apply CLI download login Signed-off-by: Michele Dolfi --- .github/workflows/checks.yml | 2 +- docling/datamodel/pipeline_options.py | 4 ++++ docling/models/picture_description_vlm_model.py | 2 ++ docling/pipeline/standard_pdf_pipeline.py | 6 +++++- docling/utils/model_downloader.py | 12 ++++++++++++ 5 files changed, 24 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 6240e9a8..89bcfd79 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -28,7 +28,7 @@ jobs: run: | for file in docs/examples/*.py; do # Skip batch_convert.py - if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description_api).py ]]; then + if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then echo "Skipping $file" continue fi diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index d6394e33..3b6401b6 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -214,6 +214,10 @@ class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions): # Config from here https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig generation_config: Dict[str, Any] = dict(max_new_tokens=200, do_sample=False) + @property + def repo_cache_folder(self) -> str: + return self.repo_id.replace("/", "--") + smolvlm_picture_description = PictureDescriptionVlmOptions( repo_id="HuggingFaceTB/SmolVLM-256M-Instruct" diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py index b4c8cf21..9fa4826d 100644 --- a/docling/models/picture_description_vlm_model.py +++ b/docling/models/picture_description_vlm_model.py @@ -27,6 +27,8 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel): if artifacts_path is None: artifacts_path = self.download_models(repo_id=self.options.repo_id) + else: + artifacts_path = Path(artifacts_path) / self.options.repo_cache_folder self.device = decide_device(accelerator_options.device) diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 3a329525..13e435f9 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -101,7 +101,11 @@ class StandardPdfPipeline(PaginatedPipeline): ] # Picture description model - if (picture_description_model := self.get_picture_description_model()) is None: + if ( + picture_description_model := self.get_picture_description_model( + artifacts_path=artifacts_path + ) + ) is None: raise RuntimeError( f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}." ) diff --git a/docling/utils/model_downloader.py b/docling/utils/model_downloader.py index 504618ec..7d22b77b 100644 --- a/docling/utils/model_downloader.py +++ b/docling/utils/model_downloader.py @@ -2,11 +2,13 @@ import logging from pathlib import Path from typing import Optional +from docling.datamodel.pipeline_options import smolvlm_picture_description from docling.datamodel.settings import settings from docling.models.code_formula_model import CodeFormulaModel from docling.models.document_picture_classifier import DocumentPictureClassifier from docling.models.easyocr_model import EasyOcrModel from docling.models.layout_model import LayoutModel +from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel from docling.models.table_structure_model import TableStructureModel _log = logging.getLogger(__name__) @@ -21,6 +23,7 @@ def download_models( with_tableformer: bool = True, with_code_formula: bool = True, with_picture_classifier: bool = True, + with_smolvlm: bool = True, with_easyocr: bool = True, ): if output_dir is None: @@ -61,6 +64,15 @@ def download_models( progress=progress, ) + if with_smolvlm: + _log.info(f"Downloading SmolVlm model...") + PictureDescriptionVlmModel.download_models( + repo_id=smolvlm_picture_description.repo_id, + local_dir=output_dir / smolvlm_picture_description.repo_cache_folder, + force=force, + progress=progress, + ) + if with_easyocr: _log.info(f"Downloading easyocr models...") EasyOcrModel.download_models(