mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-16 16:48:21 +00:00
feat: new artifacts path and CLI utility (#876)
* fix artifacts path Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add docling-models utility Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * missing formatting Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename utility to docling-tools Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename download methods and deprecation warnings Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * propagate artifacts path usage for ocr models Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move function to utils Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove unused file Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * update docs Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * simplify downloading specific model(s) Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> * minor refactor Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --------- Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
This commit is contained in:
@@ -61,6 +61,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
Processes the given batch of elements and enriches them with predictions.
|
||||
"""
|
||||
|
||||
_model_repo_folder = "CodeFormula"
|
||||
elements_batch_size = 5
|
||||
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
||||
expansion_factor = 0.03
|
||||
@@ -68,7 +69,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
artifacts_path: Optional[Path],
|
||||
options: CodeFormulaModelOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
@@ -97,9 +98,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
)
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models_hf()
|
||||
artifacts_path = self.download_models()
|
||||
else:
|
||||
artifacts_path = Path(artifacts_path)
|
||||
artifacts_path = artifacts_path / self._model_repo_folder
|
||||
|
||||
self.code_formula_model = CodeFormulaPredictor(
|
||||
artifacts_path=artifacts_path,
|
||||
@@ -108,13 +109,16 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
def download_models(
|
||||
local_dir: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
if not progress:
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/CodeFormula",
|
||||
force_download=force,
|
||||
|
||||
@@ -55,12 +55,13 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
||||
Processes a batch of elements and adds classification annotations.
|
||||
"""
|
||||
|
||||
_model_repo_folder = "DocumentFigureClassifier"
|
||||
images_scale = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
artifacts_path: Optional[Path],
|
||||
options: DocumentPictureClassifierOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
@@ -88,9 +89,9 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
||||
)
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models_hf()
|
||||
artifacts_path = self.download_models()
|
||||
else:
|
||||
artifacts_path = Path(artifacts_path)
|
||||
artifacts_path = artifacts_path / self._model_repo_folder
|
||||
|
||||
self.document_picture_classifier = DocumentFigureClassifierPredictor(
|
||||
artifacts_path=artifacts_path,
|
||||
@@ -99,13 +100,14 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
def download_models(
|
||||
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
if not progress:
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/DocumentFigureClassifier",
|
||||
force_download=force,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Iterable
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import httpx
|
||||
import numpy
|
||||
import torch
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
@@ -17,14 +20,18 @@ from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
from docling.utils.utils import download_url_with_progress
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyOcrModel(BaseOcrModel):
|
||||
_model_repo_folder = "EasyOcr"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Path],
|
||||
options: EasyOcrOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
@@ -62,15 +69,55 @@ class EasyOcrModel(BaseOcrModel):
|
||||
)
|
||||
use_gpu = self.options.use_gpu
|
||||
|
||||
download_enabled = self.options.download_enabled
|
||||
model_storage_directory = self.options.model_storage_directory
|
||||
if artifacts_path is not None and model_storage_directory is None:
|
||||
download_enabled = False
|
||||
model_storage_directory = str(artifacts_path / self._model_repo_folder)
|
||||
|
||||
self.reader = easyocr.Reader(
|
||||
lang_list=self.options.lang,
|
||||
gpu=use_gpu,
|
||||
model_storage_directory=self.options.model_storage_directory,
|
||||
model_storage_directory=model_storage_directory,
|
||||
recog_network=self.options.recog_network,
|
||||
download_enabled=self.options.download_enabled,
|
||||
download_enabled=download_enabled,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
detection_models: List[str] = ["craft"],
|
||||
recognition_models: List[str] = ["english_g2", "latin_g2"],
|
||||
local_dir: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
|
||||
from easyocr.config import detection_models as det_models_dict
|
||||
from easyocr.config import recognition_models as rec_models_dict
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
|
||||
|
||||
local_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Collect models to download
|
||||
download_list = []
|
||||
for model_name in detection_models:
|
||||
if model_name in det_models_dict:
|
||||
download_list.append(det_models_dict[model_name])
|
||||
for model_name in recognition_models:
|
||||
if model_name in rec_models_dict["gen2"]:
|
||||
download_list.append(rec_models_dict["gen2"][model_name])
|
||||
|
||||
# Download models
|
||||
for model_details in download_list:
|
||||
buf = download_url_with_progress(model_details["url"], progress=progress)
|
||||
with zipfile.ZipFile(buf, "r") as zip_ref:
|
||||
zip_ref.extractall(local_dir)
|
||||
|
||||
return local_dir
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
from docling_core.types.doc import DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
@@ -21,6 +22,8 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LayoutModel(BasePageModel):
|
||||
_model_repo_folder = "docling-models"
|
||||
_model_path = "model_artifacts/layout"
|
||||
|
||||
TEXT_ELEM_LABELS = [
|
||||
DocItemLabel.TEXT,
|
||||
@@ -42,15 +45,56 @@ class LayoutModel(BasePageModel):
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||
|
||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||
def __init__(
|
||||
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
|
||||
):
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models() / self._model_path
|
||||
else:
|
||||
# will become the default in the future
|
||||
if (artifacts_path / self._model_repo_folder).exists():
|
||||
artifacts_path = (
|
||||
artifacts_path / self._model_repo_folder / self._model_path
|
||||
)
|
||||
elif (artifacts_path / self._model_path).exists():
|
||||
warnings.warn(
|
||||
"The usage of artifacts_path containing directly "
|
||||
f"{self._model_path} is deprecated. Please point "
|
||||
"the artifacts_path to the parent containing "
|
||||
f"the {self._model_repo_folder} folder.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
artifacts_path = artifacts_path / self._model_path
|
||||
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
artifact_path=str(artifacts_path),
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
local_dir: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
if not progress:
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/docling-models",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v2.1.0",
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def draw_clusters_and_cells_side_by_side(
|
||||
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
||||
):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
|
||||
@@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder
|
||||
|
||||
|
||||
class TableStructureModel(BasePageModel):
|
||||
_model_repo_folder = "docling-models"
|
||||
_model_path = "model_artifacts/tableformer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Path,
|
||||
artifacts_path: Optional[Path],
|
||||
options: TableStructureOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
@@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel):
|
||||
|
||||
self.enabled = enabled
|
||||
if self.enabled:
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models() / self._model_path
|
||||
else:
|
||||
# will become the default in the future
|
||||
if (artifacts_path / self._model_repo_folder).exists():
|
||||
artifacts_path = (
|
||||
artifacts_path / self._model_repo_folder / self._model_path
|
||||
)
|
||||
elif (artifacts_path / self._model_path).exists():
|
||||
warnings.warn(
|
||||
"The usage of artifacts_path containing directly "
|
||||
f"{self._model_path} is deprecated. Please point "
|
||||
"the artifacts_path to the parent containing "
|
||||
f"the {self._model_repo_folder} folder.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
artifacts_path = artifacts_path / self._model_path
|
||||
|
||||
if self.mode == TableFormerMode.ACCURATE:
|
||||
artifacts_path = artifacts_path / "accurate"
|
||||
else:
|
||||
@@ -58,6 +82,24 @@ class TableStructureModel(BasePageModel):
|
||||
)
|
||||
self.scale = 2.0 # Scale up table input images to 144 dpi
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
if not progress:
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/docling-models",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v2.1.0",
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def draw_table_and_cells(
|
||||
self,
|
||||
conv_res: ConversionResult,
|
||||
|
||||
Reference in New Issue
Block a user