fix artifacts path

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-02-03 21:13:05 +01:00
parent 94751a78f4
commit 18aad34d67
6 changed files with 124 additions and 29 deletions

View File

@ -61,5 +61,7 @@ class AppSettings(BaseSettings):
perf: BatchConcurrencySettings
debug: DebugSettings
cache_dir: Path = Path.home() / ".cache" / "docling"
settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings())

View File

@ -61,13 +61,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
Processes the given batch of elements and enriches them with predictions.
"""
_model_repo_folder = "CodeFormula"
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
expansion_factor = 0.03
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Union[Path, str]],
artifacts_path: Optional[Path],
options: CodeFormulaModelOptions,
accelerator_options: AcceleratorOptions,
):
@ -98,7 +99,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
if artifacts_path is None:
artifacts_path = self.download_models_hf()
else:
artifacts_path = Path(artifacts_path)
artifacts_path = artifacts_path / self._model_repo_folder
self.code_formula_model = CodeFormulaPredictor(
artifacts_path=artifacts_path,

View File

@ -55,12 +55,13 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
Processes a batch of elements and adds classification annotations.
"""
_model_repo_folder = "DocumentFigureClassifier"
images_scale = 2
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Union[Path, str]],
artifacts_path: Optional[Path],
options: DocumentPictureClassifierOptions,
accelerator_options: AcceleratorOptions,
):
@ -90,7 +91,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
if artifacts_path is None:
artifacts_path = self.download_models_hf()
else:
artifacts_path = Path(artifacts_path)
artifacts_path = artifacts_path / self._model_repo_folder
self.document_picture_classifier = DocumentFigureClassifierPredictor(
artifacts_path=artifacts_path,

View File

@ -1,7 +1,8 @@
import copy
import logging
import warnings
from pathlib import Path
from typing import Iterable
from typing import Iterable, Optional, Union
from docling_core.types.doc import DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
@ -21,6 +22,8 @@ _log = logging.getLogger(__name__)
class LayoutModel(BasePageModel):
_model_repo_folder = "docling-models"
_model_path = "model_artifacts/layout"
TEXT_ELEM_LABELS = [
DocItemLabel.TEXT,
@ -42,15 +45,53 @@ class LayoutModel(BasePageModel):
FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
def __init__(
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
):
device = decide_device(accelerator_options.device)
if artifacts_path is None:
artifacts_path = self.download_models_hf() / self._model_path
else:
# will become the default in the future
if (artifacts_path / self._model_repo_folder).exists():
artifacts_path = (
artifacts_path / self._model_repo_folder / self._model_path
)
elif (artifacts_path / self._model_path).exists():
warnings.warn(
"The usage of artifacts_path containing directly "
f"{self._model_path} is deprecated. Please point "
"the artifacts_path to the parent containing "
f"the {self._model_repo_folder} folder.",
DeprecationWarning,
stacklevel=3,
)
artifacts_path = artifacts_path / self._model_path
self.layout_predictor = LayoutPredictor(
artifact_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
)
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
)
return Path(download_path)
def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
):

View File

@ -1,6 +1,7 @@
import copy
import warnings
from pathlib import Path
from typing import Iterable
from typing import Iterable, Optional, Union
import numpy
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder
class TableStructureModel(BasePageModel):
_model_repo_folder = "docling-models"
_model_path = "model_artifacts/tableformer"
def __init__(
self,
enabled: bool,
artifacts_path: Path,
artifacts_path: Optional[Path],
options: TableStructureOptions,
accelerator_options: AcceleratorOptions,
):
@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel):
self.enabled = enabled
if self.enabled:
if artifacts_path is None:
artifacts_path = self.download_models_hf() / self._model_path
else:
# will become the default in the future
if (artifacts_path / self._model_repo_folder).exists():
artifacts_path = (
artifacts_path / self._model_repo_folder / self._model_path
)
elif (artifacts_path / self._model_path).exists():
warnings.warn(
"The usage of artifacts_path containing directly "
f"{self._model_path} is deprecated. Please point "
"the artifacts_path to the parent containing "
f"the {self._model_repo_folder} folder.",
DeprecationWarning,
stacklevel=3,
)
artifacts_path = artifacts_path / self._model_path
if self.mode == TableFormerMode.ACCURATE:
artifacts_path = artifacts_path / "accurate"
else:
@ -58,6 +82,23 @@ class TableStructureModel(BasePageModel):
)
self.scale = 2.0 # Scale up table input images to 144 dpi
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
)
return Path(download_path)
def draw_table_and_cells(
self,
conv_res: ConversionResult,

View File

@ -17,6 +17,7 @@ from docling.datamodel.pipeline_options import (
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.document_picture_classifier import (
@ -43,17 +44,16 @@ _log = logging.getLogger(__name__)
class StandardPdfPipeline(PaginatedPipeline):
_layout_model_path = "model_artifacts/layout"
_table_model_path = "model_artifacts/tableformer"
_layout_model_path = LayoutModel._model_path
_table_model_path = TableStructureModel._model_path
def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions
if pipeline_options.artifacts_path is None:
self.artifacts_path = self.download_models_hf()
else:
self.artifacts_path = Path(pipeline_options.artifacts_path)
artifacts_path: Optional[Path] = None
if pipeline_options.artifacts_path is not None:
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
self.keep_images = (
self.pipeline_options.generate_page_images
@ -79,15 +79,13 @@ class StandardPdfPipeline(PaginatedPipeline):
ocr_model,
# Layout model
LayoutModel(
artifacts_path=self.artifacts_path
/ StandardPdfPipeline._layout_model_path,
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
),
# Table structure model
TableStructureModel(
enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path
/ StandardPdfPipeline._table_model_path,
artifacts_path=artifacts_path,
options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options,
),
@ -101,7 +99,7 @@ class StandardPdfPipeline(PaginatedPipeline):
CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment,
artifacts_path=pipeline_options.artifacts_path,
artifacts_path=artifacts_path,
options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment,
@ -111,7 +109,7 @@ class StandardPdfPipeline(PaginatedPipeline):
# Document Picture Classifier
DocumentPictureClassifier(
enabled=pipeline_options.do_picture_classification,
artifacts_path=pipeline_options.artifacts_path,
artifacts_path=artifacts_path,
options=DocumentPictureClassifierOptions(),
accelerator_options=pipeline_options.accelerator_options,
),
@ -127,18 +125,29 @@ class StandardPdfPipeline(PaginatedPipeline):
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
if local_dir is None:
local_dir = settings.cache_dir / "models"
# Make sure the folder exists
local_dir.mkdir(exist_ok=True, parents=True)
# Download model weights
LayoutModel.download_models_hf(
local_dir=local_dir / LayoutModel._model_repo_folder, force=force
)
TableStructureModel.download_models_hf(
local_dir=local_dir / TableStructureModel._model_repo_folder, force=force
)
DocumentPictureClassifier.download_models_hf(
local_dir=local_dir / DocumentPictureClassifier._model_repo_folder,
force=force,
)
CodeFormulaModel.download_models_hf(
local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force
)
return Path(download_path)
return local_dir
def get_ocr_model(self) -> Optional[BaseOcrModel]:
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):