fix artifacts path

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-02-03 21:13:05 +01:00
parent 94751a78f4
commit 18aad34d67
6 changed files with 124 additions and 29 deletions

View File

@ -61,5 +61,7 @@ class AppSettings(BaseSettings):
perf: BatchConcurrencySettings perf: BatchConcurrencySettings
debug: DebugSettings debug: DebugSettings
cache_dir: Path = Path.home() / ".cache" / "docling"
settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings()) settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings())

View File

@ -61,13 +61,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
Processes the given batch of elements and enriches them with predictions. Processes the given batch of elements and enriches them with predictions.
""" """
_model_repo_folder = "CodeFormula"
images_scale = 1.66 # = 120 dpi, aligned with training data resolution images_scale = 1.66 # = 120 dpi, aligned with training data resolution
expansion_factor = 0.03 expansion_factor = 0.03
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
artifacts_path: Optional[Union[Path, str]], artifacts_path: Optional[Path],
options: CodeFormulaModelOptions, options: CodeFormulaModelOptions,
accelerator_options: AcceleratorOptions, accelerator_options: AcceleratorOptions,
): ):
@ -98,7 +99,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
if artifacts_path is None: if artifacts_path is None:
artifacts_path = self.download_models_hf() artifacts_path = self.download_models_hf()
else: else:
artifacts_path = Path(artifacts_path) artifacts_path = artifacts_path / self._model_repo_folder
self.code_formula_model = CodeFormulaPredictor( self.code_formula_model = CodeFormulaPredictor(
artifacts_path=artifacts_path, artifacts_path=artifacts_path,

View File

@ -55,12 +55,13 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
Processes a batch of elements and adds classification annotations. Processes a batch of elements and adds classification annotations.
""" """
_model_repo_folder = "DocumentFigureClassifier"
images_scale = 2 images_scale = 2
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
artifacts_path: Optional[Union[Path, str]], artifacts_path: Optional[Path],
options: DocumentPictureClassifierOptions, options: DocumentPictureClassifierOptions,
accelerator_options: AcceleratorOptions, accelerator_options: AcceleratorOptions,
): ):
@ -90,7 +91,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
if artifacts_path is None: if artifacts_path is None:
artifacts_path = self.download_models_hf() artifacts_path = self.download_models_hf()
else: else:
artifacts_path = Path(artifacts_path) artifacts_path = artifacts_path / self._model_repo_folder
self.document_picture_classifier = DocumentFigureClassifierPredictor( self.document_picture_classifier = DocumentFigureClassifierPredictor(
artifacts_path=artifacts_path, artifacts_path=artifacts_path,

View File

@ -1,7 +1,8 @@
import copy import copy
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from typing import Iterable from typing import Iterable, Optional, Union
from docling_core.types.doc import DocItemLabel from docling_core.types.doc import DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
@ -21,6 +22,8 @@ _log = logging.getLogger(__name__)
class LayoutModel(BasePageModel): class LayoutModel(BasePageModel):
_model_repo_folder = "docling-models"
_model_path = "model_artifacts/layout"
TEXT_ELEM_LABELS = [ TEXT_ELEM_LABELS = [
DocItemLabel.TEXT, DocItemLabel.TEXT,
@ -42,15 +45,53 @@ class LayoutModel(BasePageModel):
FORMULA_LABEL = DocItemLabel.FORMULA FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions): def __init__(
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
):
device = decide_device(accelerator_options.device) device = decide_device(accelerator_options.device)
if artifacts_path is None:
artifacts_path = self.download_models_hf() / self._model_path
else:
# will become the default in the future
if (artifacts_path / self._model_repo_folder).exists():
artifacts_path = (
artifacts_path / self._model_repo_folder / self._model_path
)
elif (artifacts_path / self._model_path).exists():
warnings.warn(
"The usage of artifacts_path containing directly "
f"{self._model_path} is deprecated. Please point "
"the artifacts_path to the parent containing "
f"the {self._model_repo_folder} folder.",
DeprecationWarning,
stacklevel=3,
)
artifacts_path = artifacts_path / self._model_path
self.layout_predictor = LayoutPredictor( self.layout_predictor = LayoutPredictor(
artifact_path=str(artifacts_path), artifact_path=str(artifacts_path),
device=device, device=device,
num_threads=accelerator_options.num_threads, num_threads=accelerator_options.num_threads,
) )
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
)
return Path(download_path)
def draw_clusters_and_cells_side_by_side( def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False self, conv_res, page, clusters, mode_prefix: str, show: bool = False
): ):

View File

@ -1,6 +1,7 @@
import copy import copy
import warnings
from pathlib import Path from pathlib import Path
from typing import Iterable from typing import Iterable, Optional, Union
import numpy import numpy
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder
class TableStructureModel(BasePageModel): class TableStructureModel(BasePageModel):
_model_repo_folder = "docling-models"
_model_path = "model_artifacts/tableformer"
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
artifacts_path: Path, artifacts_path: Optional[Path],
options: TableStructureOptions, options: TableStructureOptions,
accelerator_options: AcceleratorOptions, accelerator_options: AcceleratorOptions,
): ):
@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel):
self.enabled = enabled self.enabled = enabled
if self.enabled: if self.enabled:
if artifacts_path is None:
artifacts_path = self.download_models_hf() / self._model_path
else:
# will become the default in the future
if (artifacts_path / self._model_repo_folder).exists():
artifacts_path = (
artifacts_path / self._model_repo_folder / self._model_path
)
elif (artifacts_path / self._model_path).exists():
warnings.warn(
"The usage of artifacts_path containing directly "
f"{self._model_path} is deprecated. Please point "
"the artifacts_path to the parent containing "
f"the {self._model_repo_folder} folder.",
DeprecationWarning,
stacklevel=3,
)
artifacts_path = artifacts_path / self._model_path
if self.mode == TableFormerMode.ACCURATE: if self.mode == TableFormerMode.ACCURATE:
artifacts_path = artifacts_path / "accurate" artifacts_path = artifacts_path / "accurate"
else: else:
@ -58,6 +82,23 @@ class TableStructureModel(BasePageModel):
) )
self.scale = 2.0 # Scale up table input images to 144 dpi self.scale = 2.0 # Scale up table input images to 144 dpi
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
)
return Path(download_path)
def draw_table_and_cells( def draw_table_and_cells(
self, self,
conv_res: ConversionResult, conv_res: ConversionResult,

View File

@ -17,6 +17,7 @@ from docling.datamodel.pipeline_options import (
TesseractCliOcrOptions, TesseractCliOcrOptions,
TesseractOcrOptions, TesseractOcrOptions,
) )
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.document_picture_classifier import ( from docling.models.document_picture_classifier import (
@ -43,17 +44,16 @@ _log = logging.getLogger(__name__)
class StandardPdfPipeline(PaginatedPipeline): class StandardPdfPipeline(PaginatedPipeline):
_layout_model_path = "model_artifacts/layout" _layout_model_path = LayoutModel._model_path
_table_model_path = "model_artifacts/tableformer" _table_model_path = TableStructureModel._model_path
def __init__(self, pipeline_options: PdfPipelineOptions): def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options) super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions self.pipeline_options: PdfPipelineOptions
if pipeline_options.artifacts_path is None: artifacts_path: Optional[Path] = None
self.artifacts_path = self.download_models_hf() if pipeline_options.artifacts_path is not None:
else: artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
self.artifacts_path = Path(pipeline_options.artifacts_path)
self.keep_images = ( self.keep_images = (
self.pipeline_options.generate_page_images self.pipeline_options.generate_page_images
@ -79,15 +79,13 @@ class StandardPdfPipeline(PaginatedPipeline):
ocr_model, ocr_model,
# Layout model # Layout model
LayoutModel( LayoutModel(
artifacts_path=self.artifacts_path artifacts_path=artifacts_path,
/ StandardPdfPipeline._layout_model_path,
accelerator_options=pipeline_options.accelerator_options, accelerator_options=pipeline_options.accelerator_options,
), ),
# Table structure model # Table structure model
TableStructureModel( TableStructureModel(
enabled=pipeline_options.do_table_structure, enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path artifacts_path=artifacts_path,
/ StandardPdfPipeline._table_model_path,
options=pipeline_options.table_structure_options, options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options, accelerator_options=pipeline_options.accelerator_options,
), ),
@ -101,7 +99,7 @@ class StandardPdfPipeline(PaginatedPipeline):
CodeFormulaModel( CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment, or pipeline_options.do_formula_enrichment,
artifacts_path=pipeline_options.artifacts_path, artifacts_path=artifacts_path,
options=CodeFormulaModelOptions( options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment, do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment, do_formula_enrichment=pipeline_options.do_formula_enrichment,
@ -111,7 +109,7 @@ class StandardPdfPipeline(PaginatedPipeline):
# Document Picture Classifier # Document Picture Classifier
DocumentPictureClassifier( DocumentPictureClassifier(
enabled=pipeline_options.do_picture_classification, enabled=pipeline_options.do_picture_classification,
artifacts_path=pipeline_options.artifacts_path, artifacts_path=artifacts_path,
options=DocumentPictureClassifierOptions(), options=DocumentPictureClassifierOptions(),
accelerator_options=pipeline_options.accelerator_options, accelerator_options=pipeline_options.accelerator_options,
), ),
@ -127,18 +125,29 @@ class StandardPdfPipeline(PaginatedPipeline):
def download_models_hf( def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False local_dir: Optional[Path] = None, force: bool = False
) -> Path: ) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars() if local_dir is None:
download_path = snapshot_download( local_dir = settings.cache_dir / "models"
repo_id="ds4sd/docling-models",
force_download=force, # Make sure the folder exists
local_dir=local_dir, local_dir.mkdir(exist_ok=True, parents=True)
revision="v2.1.0",
# Download model weights
LayoutModel.download_models_hf(
local_dir=local_dir / LayoutModel._model_repo_folder, force=force
)
TableStructureModel.download_models_hf(
local_dir=local_dir / TableStructureModel._model_repo_folder, force=force
)
DocumentPictureClassifier.download_models_hf(
local_dir=local_dir / DocumentPictureClassifier._model_repo_folder,
force=force,
)
CodeFormulaModel.download_models_hf(
local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force
) )
return Path(download_path) return local_dir
def get_ocr_model(self) -> Optional[BaseOcrModel]: def get_ocr_model(self) -> Optional[BaseOcrModel]:
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions): if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):