From 18aad34d6767f20a1c2113a61222952084daacac Mon Sep 17 00:00:00 2001 From: Michele Dolfi Date: Mon, 3 Feb 2025 21:13:05 +0100 Subject: [PATCH] fix artifacts path Signed-off-by: Michele Dolfi --- docling/datamodel/settings.py | 2 + docling/models/code_formula_model.py | 5 +- docling/models/document_picture_classifier.py | 5 +- docling/models/layout_model.py | 45 +++++++++++++++- docling/models/table_structure_model.py | 45 +++++++++++++++- docling/pipeline/standard_pdf_pipeline.py | 51 +++++++++++-------- 6 files changed, 124 insertions(+), 29 deletions(-) diff --git a/docling/datamodel/settings.py b/docling/datamodel/settings.py index 92856203..439ffe74 100644 --- a/docling/datamodel/settings.py +++ b/docling/datamodel/settings.py @@ -61,5 +61,7 @@ class AppSettings(BaseSettings): perf: BatchConcurrencySettings debug: DebugSettings + cache_dir: Path = Path.home() / ".cache" / "docling" + settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings()) diff --git a/docling/models/code_formula_model.py b/docling/models/code_formula_model.py index e4d56945..2e380c6c 100644 --- a/docling/models/code_formula_model.py +++ b/docling/models/code_formula_model.py @@ -61,13 +61,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel): Processes the given batch of elements and enriches them with predictions. """ + _model_repo_folder = "CodeFormula" images_scale = 1.66 # = 120 dpi, aligned with training data resolution expansion_factor = 0.03 def __init__( self, enabled: bool, - artifacts_path: Optional[Union[Path, str]], + artifacts_path: Optional[Path], options: CodeFormulaModelOptions, accelerator_options: AcceleratorOptions, ): @@ -98,7 +99,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel): if artifacts_path is None: artifacts_path = self.download_models_hf() else: - artifacts_path = Path(artifacts_path) + artifacts_path = artifacts_path / self._model_repo_folder self.code_formula_model = CodeFormulaPredictor( artifacts_path=artifacts_path, diff --git a/docling/models/document_picture_classifier.py b/docling/models/document_picture_classifier.py index 6e2d90b4..ff981d92 100644 --- a/docling/models/document_picture_classifier.py +++ b/docling/models/document_picture_classifier.py @@ -55,12 +55,13 @@ class DocumentPictureClassifier(BaseEnrichmentModel): Processes a batch of elements and adds classification annotations. """ + _model_repo_folder = "DocumentFigureClassifier" images_scale = 2 def __init__( self, enabled: bool, - artifacts_path: Optional[Union[Path, str]], + artifacts_path: Optional[Path], options: DocumentPictureClassifierOptions, accelerator_options: AcceleratorOptions, ): @@ -90,7 +91,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel): if artifacts_path is None: artifacts_path = self.download_models_hf() else: - artifacts_path = Path(artifacts_path) + artifacts_path = artifacts_path / self._model_repo_folder self.document_picture_classifier = DocumentFigureClassifierPredictor( artifacts_path=artifacts_path, diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 69193c94..2330cc29 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -1,7 +1,8 @@ import copy import logging +import warnings from pathlib import Path -from typing import Iterable +from typing import Iterable, Optional, Union from docling_core.types.doc import DocItemLabel from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor @@ -21,6 +22,8 @@ _log = logging.getLogger(__name__) class LayoutModel(BasePageModel): + _model_repo_folder = "docling-models" + _model_path = "model_artifacts/layout" TEXT_ELEM_LABELS = [ DocItemLabel.TEXT, @@ -42,15 +45,53 @@ class LayoutModel(BasePageModel): FORMULA_LABEL = DocItemLabel.FORMULA CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] - def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions): + def __init__( + self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions + ): device = decide_device(accelerator_options.device) + if artifacts_path is None: + artifacts_path = self.download_models_hf() / self._model_path + else: + # will become the default in the future + if (artifacts_path / self._model_repo_folder).exists(): + artifacts_path = ( + artifacts_path / self._model_repo_folder / self._model_path + ) + elif (artifacts_path / self._model_path).exists(): + warnings.warn( + "The usage of artifacts_path containing directly " + f"{self._model_path} is deprecated. Please point " + "the artifacts_path to the parent containing " + f"the {self._model_repo_folder} folder.", + DeprecationWarning, + stacklevel=3, + ) + artifacts_path = artifacts_path / self._model_path + self.layout_predictor = LayoutPredictor( artifact_path=str(artifacts_path), device=device, num_threads=accelerator_options.num_threads, ) + @staticmethod + def download_models_hf( + local_dir: Optional[Path] = None, force: bool = False + ) -> Path: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import disable_progress_bars + + disable_progress_bars() + download_path = snapshot_download( + repo_id="ds4sd/docling-models", + force_download=force, + local_dir=local_dir, + revision="v2.1.0", + ) + + return Path(download_path) + def draw_clusters_and_cells_side_by_side( self, conv_res, page, clusters, mode_prefix: str, show: bool = False ): diff --git a/docling/models/table_structure_model.py b/docling/models/table_structure_model.py index f17cbed0..297b6c2e 100644 --- a/docling/models/table_structure_model.py +++ b/docling/models/table_structure_model.py @@ -1,6 +1,7 @@ import copy +import warnings from pathlib import Path -from typing import Iterable +from typing import Iterable, Optional, Union import numpy from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell @@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder class TableStructureModel(BasePageModel): + _model_repo_folder = "docling-models" + _model_path = "model_artifacts/tableformer" + def __init__( self, enabled: bool, - artifacts_path: Path, + artifacts_path: Optional[Path], options: TableStructureOptions, accelerator_options: AcceleratorOptions, ): @@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel): self.enabled = enabled if self.enabled: + + if artifacts_path is None: + artifacts_path = self.download_models_hf() / self._model_path + else: + # will become the default in the future + if (artifacts_path / self._model_repo_folder).exists(): + artifacts_path = ( + artifacts_path / self._model_repo_folder / self._model_path + ) + elif (artifacts_path / self._model_path).exists(): + warnings.warn( + "The usage of artifacts_path containing directly " + f"{self._model_path} is deprecated. Please point " + "the artifacts_path to the parent containing " + f"the {self._model_repo_folder} folder.", + DeprecationWarning, + stacklevel=3, + ) + artifacts_path = artifacts_path / self._model_path + if self.mode == TableFormerMode.ACCURATE: artifacts_path = artifacts_path / "accurate" else: @@ -58,6 +82,23 @@ class TableStructureModel(BasePageModel): ) self.scale = 2.0 # Scale up table input images to 144 dpi + @staticmethod + def download_models_hf( + local_dir: Optional[Path] = None, force: bool = False + ) -> Path: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import disable_progress_bars + + disable_progress_bars() + download_path = snapshot_download( + repo_id="ds4sd/docling-models", + force_download=force, + local_dir=local_dir, + revision="v2.1.0", + ) + + return Path(download_path) + def draw_table_and_cells( self, conv_res: ConversionResult, diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index fe2201d6..f06bf5c6 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -17,6 +17,7 @@ from docling.datamodel.pipeline_options import ( TesseractCliOcrOptions, TesseractOcrOptions, ) +from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions from docling.models.document_picture_classifier import ( @@ -43,17 +44,16 @@ _log = logging.getLogger(__name__) class StandardPdfPipeline(PaginatedPipeline): - _layout_model_path = "model_artifacts/layout" - _table_model_path = "model_artifacts/tableformer" + _layout_model_path = LayoutModel._model_path + _table_model_path = TableStructureModel._model_path def __init__(self, pipeline_options: PdfPipelineOptions): super().__init__(pipeline_options) self.pipeline_options: PdfPipelineOptions - if pipeline_options.artifacts_path is None: - self.artifacts_path = self.download_models_hf() - else: - self.artifacts_path = Path(pipeline_options.artifacts_path) + artifacts_path: Optional[Path] = None + if pipeline_options.artifacts_path is not None: + artifacts_path = Path(pipeline_options.artifacts_path).expanduser() self.keep_images = ( self.pipeline_options.generate_page_images @@ -79,15 +79,13 @@ class StandardPdfPipeline(PaginatedPipeline): ocr_model, # Layout model LayoutModel( - artifacts_path=self.artifacts_path - / StandardPdfPipeline._layout_model_path, + artifacts_path=artifacts_path, accelerator_options=pipeline_options.accelerator_options, ), # Table structure model TableStructureModel( enabled=pipeline_options.do_table_structure, - artifacts_path=self.artifacts_path - / StandardPdfPipeline._table_model_path, + artifacts_path=artifacts_path, options=pipeline_options.table_structure_options, accelerator_options=pipeline_options.accelerator_options, ), @@ -101,7 +99,7 @@ class StandardPdfPipeline(PaginatedPipeline): CodeFormulaModel( enabled=pipeline_options.do_code_enrichment or pipeline_options.do_formula_enrichment, - artifacts_path=pipeline_options.artifacts_path, + artifacts_path=artifacts_path, options=CodeFormulaModelOptions( do_code_enrichment=pipeline_options.do_code_enrichment, do_formula_enrichment=pipeline_options.do_formula_enrichment, @@ -111,7 +109,7 @@ class StandardPdfPipeline(PaginatedPipeline): # Document Picture Classifier DocumentPictureClassifier( enabled=pipeline_options.do_picture_classification, - artifacts_path=pipeline_options.artifacts_path, + artifacts_path=artifacts_path, options=DocumentPictureClassifierOptions(), accelerator_options=pipeline_options.accelerator_options, ), @@ -127,18 +125,29 @@ class StandardPdfPipeline(PaginatedPipeline): def download_models_hf( local_dir: Optional[Path] = None, force: bool = False ) -> Path: - from huggingface_hub import snapshot_download - from huggingface_hub.utils import disable_progress_bars - disable_progress_bars() - download_path = snapshot_download( - repo_id="ds4sd/docling-models", - force_download=force, - local_dir=local_dir, - revision="v2.1.0", + if local_dir is None: + local_dir = settings.cache_dir / "models" + + # Make sure the folder exists + local_dir.mkdir(exist_ok=True, parents=True) + + # Download model weights + LayoutModel.download_models_hf( + local_dir=local_dir / LayoutModel._model_repo_folder, force=force + ) + TableStructureModel.download_models_hf( + local_dir=local_dir / TableStructureModel._model_repo_folder, force=force + ) + DocumentPictureClassifier.download_models_hf( + local_dir=local_dir / DocumentPictureClassifier._model_repo_folder, + force=force, + ) + CodeFormulaModel.download_models_hf( + local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force ) - return Path(download_path) + return local_dir def get_ocr_model(self) -> Optional[BaseOcrModel]: if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):