move function to utils

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-02-05 18:20:26 +01:00
parent a11fd1f157
commit 1b8adb860a
3 changed files with 101 additions and 76 deletions

View File

@ -12,12 +12,14 @@ from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel from docling.models.layout_model import LayoutModel
from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.table_structure_model import TableStructureModel from docling.models.table_structure_model import TableStructureModel
from docling.utils.models_downloader import download_all
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch") warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr") warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
from rich.console import Console from rich.console import Console
from rich.logging import RichHandler
console = Console() console = Console()
err_console = Console(stderr=True) err_console = Console(stderr=True)
@ -85,60 +87,33 @@ def download(
typer.Option(..., help="If true, the rapidocr model weights are downloaded."), typer.Option(..., help="If true, the rapidocr model weights are downloaded."),
] = True, ] = True,
): ):
# Make sure the folder exists
output_dir.mkdir(exist_ok=True, parents=True)
show_progress = not quite
if layout:
if not quite: if not quite:
typer.secho(f"Downloading layout model...", fg="blue") FORMAT = "%(message)s"
LayoutModel.download_models( logging.basicConfig(
local_dir=output_dir / LayoutModel._model_repo_folder, level=logging.INFO,
force=force, format="[blue]%(message)s[/blue]",
progress=show_progress, datefmt="[%X]",
handlers=[RichHandler(show_level=False, show_time=False, markup=True)],
) )
if tableformer: output_dir = download_all(
if not quite: output_dir=output_dir,
typer.secho(f"Downloading tableformer model...", fg="blue")
TableStructureModel.download_models(
local_dir=output_dir / TableStructureModel._model_repo_folder,
force=force, force=force,
progress=show_progress, progress=(not quite),
) layout=layout,
tableformer=tableformer,
if picture_classifier: code_formula=code_formula,
if not quite: picture_classifier=picture_classifier,
typer.secho(f"Downloading picture classifier model...", fg="blue") easyocr=easyocr,
DocumentPictureClassifier.download_models( rapidocr=rapidocr,
local_dir=output_dir / DocumentPictureClassifier._model_repo_folder,
force=force,
progress=show_progress,
)
if code_formula:
if not quite:
typer.secho(f"Downloading code formula model...", fg="blue")
CodeFormulaModel.download_models(
local_dir=output_dir / CodeFormulaModel._model_repo_folder,
force=force,
progress=show_progress,
)
if easyocr:
if not quite:
typer.secho(f"Downloading easyocr models...", fg="blue")
EasyOcrModel.download_models(
local_dir=output_dir / EasyOcrModel._model_repo_folder,
force=force,
progress=show_progress,
) )
if quite: if quite:
typer.echo(output_dir) typer.echo(output_dir)
else: else:
typer.secho(f"All models downloaded in the directory {output_dir}.", fg="green") typer.secho(
f"\nAll models downloaded in the directory {output_dir}.", fg="green"
)
console.print( console.print(
"\n", "\n",
@ -150,11 +125,6 @@ def download(
) )
# @app.command(hidden=True)
# def other():
# raise NotImplementedError()
click_app = typer.main.get_command(app) click_app = typer.main.get_command(app)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -39,6 +39,7 @@ from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel from docling.models.tesseract_ocr_model import TesseractOcrModel
from docling.pipeline.base_pipeline import PaginatedPipeline from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.models_downloader import download_all
from docling.utils.profiling import ProfilingScope, TimeRecorder from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -129,33 +130,13 @@ class StandardPdfPipeline(PaginatedPipeline):
warnings.warn( warnings.warn(
"The usage of StandardPdfPipeline.download_models_hf() is deprecated " "The usage of StandardPdfPipeline.download_models_hf() is deprecated "
"use instead the utility `docling-tools models download`, or " "use instead the utility `docling-tools models download`, or "
"the upstream method in docling.utils.", "the upstream method docling.utils.models_downloader.download_all()",
DeprecationWarning, DeprecationWarning,
stacklevel=3, stacklevel=3,
) )
if local_dir is None: output_dir = download_all(output_dir=local_dir, force=force, progress=False)
local_dir = settings.cache_dir / "models" return output_dir
# Make sure the folder exists
local_dir.mkdir(exist_ok=True, parents=True)
# Download model weights
LayoutModel.download_models(
local_dir=local_dir / LayoutModel._model_repo_folder, force=force
)
TableStructureModel.download_models(
local_dir=local_dir / TableStructureModel._model_repo_folder, force=force
)
DocumentPictureClassifier.download_models(
local_dir=local_dir / DocumentPictureClassifier._model_repo_folder,
force=force,
)
CodeFormulaModel.download_models(
local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force
)
return local_dir
def get_ocr_model( def get_ocr_model(
self, artifacts_path: Optional[Path] = None self, artifacts_path: Optional[Path] = None

View File

@ -0,0 +1,74 @@
import logging
from pathlib import Path
from typing import Optional
from docling.datamodel.settings import settings
from docling.models.code_formula_model import CodeFormulaModel
from docling.models.document_picture_classifier import DocumentPictureClassifier
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.table_structure_model import TableStructureModel
# Module-level logger; the download helper below reports each model download at INFO level.
_log = logging.getLogger(__name__)
def download_all(
    output_dir: Optional[Path] = None,
    *,
    force: bool = False,
    progress: bool = False,
    layout: bool = True,
    tableformer: bool = True,
    code_formula: bool = True,
    picture_classifier: bool = True,
    easyocr: bool = True,
    rapidocr: bool = True,
) -> Path:
    """Download the selected model weights into a local directory.

    For every enabled flag, the matching model class's ``download_models``
    is called with ``local_dir`` set to ``output_dir / <Model>._model_repo_folder``
    and with ``force`` and ``progress`` forwarded unchanged.

    Args:
        output_dir: Destination directory. When ``None``, falls back to
            ``settings.cache_dir / "models"``. Path-like strings are accepted
            and converted to ``Path``.
        force: Forwarded to each ``download_models`` call.
        progress: Forwarded to each ``download_models`` call.
        layout: Download the layout model.
        tableformer: Download the tableformer model.
        code_formula: Download the code formula model.
        picture_classifier: Download the picture classifier model.
        easyocr: Download the easyocr models.
        rapidocr: Accepted for interface compatibility.
            NOTE(review): this flag is not read anywhere in this function, so
            no rapidocr download is triggered here -- confirm whether that is
            intended.

    Returns:
        The directory that the model subfolders were downloaded into.
    """
    if output_dir is None:
        output_dir = settings.cache_dir / "models"
    else:
        # Tolerate callers that hand over a plain string path.
        output_dir = Path(output_dir)

    # Make sure the folder exists
    output_dir.mkdir(exist_ok=True, parents=True)

    if layout:
        _log.info("Downloading layout model...")
        LayoutModel.download_models(
            local_dir=output_dir / LayoutModel._model_repo_folder,
            force=force,
            progress=progress,
        )

    if tableformer:
        _log.info("Downloading tableformer model...")
        TableStructureModel.download_models(
            local_dir=output_dir / TableStructureModel._model_repo_folder,
            force=force,
            progress=progress,
        )

    if picture_classifier:
        _log.info("Downloading picture classifier model...")
        DocumentPictureClassifier.download_models(
            local_dir=output_dir / DocumentPictureClassifier._model_repo_folder,
            force=force,
            progress=progress,
        )

    if code_formula:
        _log.info("Downloading code formula model...")
        CodeFormulaModel.download_models(
            local_dir=output_dir / CodeFormulaModel._model_repo_folder,
            force=force,
            progress=progress,
        )

    if easyocr:
        _log.info("Downloading easyocr models...")
        EasyOcrModel.download_models(
            local_dir=output_dir / EasyOcrModel._model_repo_folder,
            force=force,
            progress=progress,
        )

    return output_dir