mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-02 15:32:30 +00:00
move function to utils
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
a11fd1f157
commit
1b8adb860a
@ -12,12 +12,14 @@ from docling.models.easyocr_model import EasyOcrModel
|
|||||||
from docling.models.layout_model import LayoutModel
|
from docling.models.layout_model import LayoutModel
|
||||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||||
from docling.models.table_structure_model import TableStructureModel
|
from docling.models.table_structure_model import TableStructureModel
|
||||||
|
from docling.utils.models_downloader import download_all
|
||||||
|
|
||||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
|
||||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
|
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
err_console = Console(stderr=True)
|
err_console = Console(stderr=True)
|
||||||
@ -85,60 +87,33 @@ def download(
|
|||||||
typer.Option(..., help="If true, the rapidocr model weights are downloaded."),
|
typer.Option(..., help="If true, the rapidocr model weights are downloaded."),
|
||||||
] = True,
|
] = True,
|
||||||
):
|
):
|
||||||
# Make sure the folder exists
|
|
||||||
output_dir.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
show_progress = not quite
|
|
||||||
|
|
||||||
if layout:
|
|
||||||
if not quite:
|
if not quite:
|
||||||
typer.secho(f"Downloading layout model...", fg="blue")
|
FORMAT = "%(message)s"
|
||||||
LayoutModel.download_models(
|
logging.basicConfig(
|
||||||
local_dir=output_dir / LayoutModel._model_repo_folder,
|
level=logging.INFO,
|
||||||
force=force,
|
format="[blue]%(message)s[/blue]",
|
||||||
progress=show_progress,
|
datefmt="[%X]",
|
||||||
|
handlers=[RichHandler(show_level=False, show_time=False, markup=True)],
|
||||||
)
|
)
|
||||||
|
|
||||||
if tableformer:
|
output_dir = download_all(
|
||||||
if not quite:
|
output_dir=output_dir,
|
||||||
typer.secho(f"Downloading tableformer model...", fg="blue")
|
|
||||||
TableStructureModel.download_models(
|
|
||||||
local_dir=output_dir / TableStructureModel._model_repo_folder,
|
|
||||||
force=force,
|
force=force,
|
||||||
progress=show_progress,
|
progress=(not quite),
|
||||||
)
|
layout=layout,
|
||||||
|
tableformer=tableformer,
|
||||||
if picture_classifier:
|
code_formula=code_formula,
|
||||||
if not quite:
|
picture_classifier=picture_classifier,
|
||||||
typer.secho(f"Downloading picture classifier model...", fg="blue")
|
easyocr=easyocr,
|
||||||
DocumentPictureClassifier.download_models(
|
rapidocr=rapidocr,
|
||||||
local_dir=output_dir / DocumentPictureClassifier._model_repo_folder,
|
|
||||||
force=force,
|
|
||||||
progress=show_progress,
|
|
||||||
)
|
|
||||||
|
|
||||||
if code_formula:
|
|
||||||
if not quite:
|
|
||||||
typer.secho(f"Downloading code formula model...", fg="blue")
|
|
||||||
CodeFormulaModel.download_models(
|
|
||||||
local_dir=output_dir / CodeFormulaModel._model_repo_folder,
|
|
||||||
force=force,
|
|
||||||
progress=show_progress,
|
|
||||||
)
|
|
||||||
|
|
||||||
if easyocr:
|
|
||||||
if not quite:
|
|
||||||
typer.secho(f"Downloading easyocr models...", fg="blue")
|
|
||||||
EasyOcrModel.download_models(
|
|
||||||
local_dir=output_dir / EasyOcrModel._model_repo_folder,
|
|
||||||
force=force,
|
|
||||||
progress=show_progress,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if quite:
|
if quite:
|
||||||
typer.echo(output_dir)
|
typer.echo(output_dir)
|
||||||
else:
|
else:
|
||||||
typer.secho(f"All models downloaded in the directory {output_dir}.", fg="green")
|
typer.secho(
|
||||||
|
f"\nAll models downloaded in the directory {output_dir}.", fg="green"
|
||||||
|
)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
"\n",
|
"\n",
|
||||||
@ -150,11 +125,6 @@ def download(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# @app.command(hidden=True)
|
|
||||||
# def other():
|
|
||||||
# raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
click_app = typer.main.get_command(app)
|
click_app = typer.main.get_command(app)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -39,6 +39,7 @@ from docling.models.table_structure_model import TableStructureModel
|
|||||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||||
|
from docling.utils.models_downloader import download_all
|
||||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
@ -129,33 +130,13 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The usage of StandardPdfPipeline.download_models_hf() is deprecated "
|
"The usage of StandardPdfPipeline.download_models_hf() is deprecated "
|
||||||
"use instead the utility `docling-tools models download`, or "
|
"use instead the utility `docling-tools models download`, or "
|
||||||
"the upstream method in docling.utils.",
|
"the upstream method docling.utils.models_downloader.download_all()",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
stacklevel=3,
|
stacklevel=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
if local_dir is None:
|
output_dir = download_all(output_dir=local_dir, force=force, progress=False)
|
||||||
local_dir = settings.cache_dir / "models"
|
return output_dir
|
||||||
|
|
||||||
# Make sure the folder exists
|
|
||||||
local_dir.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
# Download model weights
|
|
||||||
LayoutModel.download_models(
|
|
||||||
local_dir=local_dir / LayoutModel._model_repo_folder, force=force
|
|
||||||
)
|
|
||||||
TableStructureModel.download_models(
|
|
||||||
local_dir=local_dir / TableStructureModel._model_repo_folder, force=force
|
|
||||||
)
|
|
||||||
DocumentPictureClassifier.download_models(
|
|
||||||
local_dir=local_dir / DocumentPictureClassifier._model_repo_folder,
|
|
||||||
force=force,
|
|
||||||
)
|
|
||||||
CodeFormulaModel.download_models(
|
|
||||||
local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force
|
|
||||||
)
|
|
||||||
|
|
||||||
return local_dir
|
|
||||||
|
|
||||||
def get_ocr_model(
|
def get_ocr_model(
|
||||||
self, artifacts_path: Optional[Path] = None
|
self, artifacts_path: Optional[Path] = None
|
||||||
|
74
docling/utils/models_downloader.py
Normal file
74
docling/utils/models_downloader.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from docling.datamodel.settings import settings
|
||||||
|
from docling.models.code_formula_model import CodeFormulaModel
|
||||||
|
from docling.models.document_picture_classifier import DocumentPictureClassifier
|
||||||
|
from docling.models.easyocr_model import EasyOcrModel
|
||||||
|
from docling.models.layout_model import LayoutModel
|
||||||
|
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||||
|
from docling.models.table_structure_model import TableStructureModel
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def download_all(
|
||||||
|
output_dir: Optional[Path] = None,
|
||||||
|
*,
|
||||||
|
force: bool = False,
|
||||||
|
progress: bool = False,
|
||||||
|
layout: bool = True,
|
||||||
|
tableformer: bool = True,
|
||||||
|
code_formula: bool = True,
|
||||||
|
picture_classifier: bool = True,
|
||||||
|
easyocr: bool = True,
|
||||||
|
rapidocr: bool = True,
|
||||||
|
):
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = settings.cache_dir / "models"
|
||||||
|
|
||||||
|
# Make sure the folder exists
|
||||||
|
output_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
if layout:
|
||||||
|
_log.info(f"Downloading layout model...")
|
||||||
|
LayoutModel.download_models(
|
||||||
|
local_dir=output_dir / LayoutModel._model_repo_folder,
|
||||||
|
force=force,
|
||||||
|
progress=progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tableformer:
|
||||||
|
_log.info(f"Downloading tableformer model...")
|
||||||
|
TableStructureModel.download_models(
|
||||||
|
local_dir=output_dir / TableStructureModel._model_repo_folder,
|
||||||
|
force=force,
|
||||||
|
progress=progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
if picture_classifier:
|
||||||
|
_log.info(f"Downloading picture classifier model...")
|
||||||
|
DocumentPictureClassifier.download_models(
|
||||||
|
local_dir=output_dir / DocumentPictureClassifier._model_repo_folder,
|
||||||
|
force=force,
|
||||||
|
progress=progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
if code_formula:
|
||||||
|
_log.info(f"Downloading code formula model...")
|
||||||
|
CodeFormulaModel.download_models(
|
||||||
|
local_dir=output_dir / CodeFormulaModel._model_repo_folder,
|
||||||
|
force=force,
|
||||||
|
progress=progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
if easyocr:
|
||||||
|
_log.info(f"Downloading easyocr models...")
|
||||||
|
EasyOcrModel.download_models(
|
||||||
|
local_dir=output_dir / EasyOcrModel._model_repo_folder,
|
||||||
|
force=force,
|
||||||
|
progress=progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output_dir
|
Loading…
Reference in New Issue
Block a user