diff --git a/docling/cli/models_download.py b/docling/cli/models_download.py index c5c4d37b..2db36521 100644 --- a/docling/cli/models_download.py +++ b/docling/cli/models_download.py @@ -12,12 +12,14 @@ from docling.models.easyocr_model import EasyOcrModel from docling.models.layout_model import LayoutModel from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.table_structure_model import TableStructureModel +from docling.utils.models_downloader import download_all warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch") warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr") _log = logging.getLogger(__name__) from rich.console import Console +from rich.logging import RichHandler console = Console() err_console = Console(stderr=True) @@ -85,60 +87,33 @@ def download( typer.Option(..., help="If true, the rapidocr model weights are downloaded."), ] = True, ): - # Make sure the folder exists - output_dir.mkdir(exist_ok=True, parents=True) - - show_progress = not quite - - if layout: - if not quite: - typer.secho(f"Downloading layout model...", fg="blue") - LayoutModel.download_models( - local_dir=output_dir / LayoutModel._model_repo_folder, - force=force, - progress=show_progress, + if not quite: + FORMAT = "%(message)s" + logging.basicConfig( + level=logging.INFO, + format="[blue]%(message)s[/blue]", + datefmt="[%X]", + handlers=[RichHandler(show_level=False, show_time=False, markup=True)], ) - if tableformer: - if not quite: - typer.secho(f"Downloading tableformer model...", fg="blue") - TableStructureModel.download_models( - local_dir=output_dir / TableStructureModel._model_repo_folder, - force=force, - progress=show_progress, - ) - - if picture_classifier: - if not quite: - typer.secho(f"Downloading picture classifier model...", fg="blue") - DocumentPictureClassifier.download_models( - local_dir=output_dir / DocumentPictureClassifier._model_repo_folder, - force=force, - progress=show_progress, - ) - - if code_formula: - if not quite: - typer.secho(f"Downloading code formula model...", fg="blue") - CodeFormulaModel.download_models( - local_dir=output_dir / CodeFormulaModel._model_repo_folder, - force=force, - progress=show_progress, - ) - - if easyocr: - if not quite: - typer.secho(f"Downloading easyocr models...", fg="blue") - EasyOcrModel.download_models( - local_dir=output_dir / EasyOcrModel._model_repo_folder, - force=force, - progress=show_progress, - ) + output_dir = download_all( + output_dir=output_dir, + force=force, + progress=(not quite), + layout=layout, + tableformer=tableformer, + code_formula=code_formula, + picture_classifier=picture_classifier, + easyocr=easyocr, + rapidocr=rapidocr, + ) if quite: typer.echo(output_dir) else: - typer.secho(f"All models downloaded in the directory {output_dir}.", fg="green") + typer.secho( + f"\nAll models downloaded in the directory {output_dir}.", fg="green" + ) console.print( "\n", @@ -150,11 +125,6 @@ def download( ) -# @app.command(hidden=True) -# def other(): -# raise NotImplementedError() - - click_app = typer.main.get_command(app) if __name__ == "__main__": diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index fbdc22c3..85bd1075 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -39,6 +39,7 @@ from docling.models.table_structure_model import TableStructureModel from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel from docling.models.tesseract_ocr_model import TesseractOcrModel from docling.pipeline.base_pipeline import PaginatedPipeline +from docling.utils.models_downloader import download_all from docling.utils.profiling import ProfilingScope, TimeRecorder _log = logging.getLogger(__name__) @@ -129,33 +130,13 @@ class StandardPdfPipeline(PaginatedPipeline): warnings.warn( "The usage of StandardPdfPipeline.download_models_hf() is deprecated " "use instead the utility `docling-tools models download`, or " - "the upstream method in docling.utils.", + "the upstream method docling.utils.models_downloader.download_all()", DeprecationWarning, stacklevel=3, ) - if local_dir is None: - local_dir = settings.cache_dir / "models" - - # Make sure the folder exists - local_dir.mkdir(exist_ok=True, parents=True) - - # Download model weights - LayoutModel.download_models( - local_dir=local_dir / LayoutModel._model_repo_folder, force=force - ) - TableStructureModel.download_models( - local_dir=local_dir / TableStructureModel._model_repo_folder, force=force - ) - DocumentPictureClassifier.download_models( - local_dir=local_dir / DocumentPictureClassifier._model_repo_folder, - force=force, - ) - CodeFormulaModel.download_models( - local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force - ) - - return local_dir + output_dir = download_all(output_dir=local_dir, force=force, progress=False) + return output_dir def get_ocr_model( self, artifacts_path: Optional[Path] = None diff --git a/docling/utils/models_downloader.py b/docling/utils/models_downloader.py new file mode 100644 index 00000000..27307fd1 --- /dev/null +++ b/docling/utils/models_downloader.py @@ -0,0 +1,74 @@ +import logging +from pathlib import Path +from typing import Optional + +from docling.datamodel.settings import settings +from docling.models.code_formula_model import CodeFormulaModel +from docling.models.document_picture_classifier import DocumentPictureClassifier +from docling.models.easyocr_model import EasyOcrModel +from docling.models.layout_model import LayoutModel +from docling.models.rapid_ocr_model import RapidOcrModel +from docling.models.table_structure_model import TableStructureModel + +_log = logging.getLogger(__name__) + + +def download_all( + output_dir: Optional[Path] = None, + *, + force: bool = False, + progress: bool = False, + layout: bool = True, + tableformer: bool = True, + code_formula: bool = True, + picture_classifier: bool = True, + easyocr: bool = True, + rapidocr: bool = True, +): + if output_dir is None: + output_dir = settings.cache_dir / "models" + + # Make sure the folder exists + output_dir.mkdir(exist_ok=True, parents=True) + + if layout: + _log.info(f"Downloading layout model...") + LayoutModel.download_models( + local_dir=output_dir / LayoutModel._model_repo_folder, + force=force, + progress=progress, + ) + + if tableformer: + _log.info(f"Downloading tableformer model...") + TableStructureModel.download_models( + local_dir=output_dir / TableStructureModel._model_repo_folder, + force=force, + progress=progress, + ) + + if picture_classifier: + _log.info(f"Downloading picture classifier model...") + DocumentPictureClassifier.download_models( + local_dir=output_dir / DocumentPictureClassifier._model_repo_folder, + force=force, + progress=progress, + ) + + if code_formula: + _log.info(f"Downloading code formula model...") + CodeFormulaModel.download_models( + local_dir=output_dir / CodeFormulaModel._model_repo_folder, + force=force, + progress=progress, + ) + + if easyocr: + _log.info(f"Downloading easyocr models...") + EasyOcrModel.download_models( + local_dir=output_dir / EasyOcrModel._model_repo_folder, + force=force, + progress=progress, + ) + + return output_dir