add --all flag to model download CLI

Signed-off-by: Panos Vagenas <pva@zurich.ibm.com>
This commit is contained in:
Panos Vagenas 2025-02-26 13:00:49 +01:00
parent 560164f613
commit abd714b64b

View File

@ -57,14 +57,23 @@ def download(
), ),
] = (settings.cache_dir / "models"), ] = (settings.cache_dir / "models"),
force: Annotated[ force: Annotated[
bool, typer.Option(..., help="If true, the download will be forced") bool, typer.Option(..., help="If true, the download will be forced.")
] = False, ] = False,
models: Annotated[ models: Annotated[
Optional[list[_AvailableModels]], Optional[list[_AvailableModels]],
typer.Argument( typer.Argument(
help=f"Models to download (default behavior: a predefined set of models will be downloaded)", help=f"Models to download (default behavior: a predefined set of models will be downloaded).",
), ),
] = None, ] = None,
all: Annotated[
bool,
typer.Option(
...,
"--all",
help="If true, all available models will be downloaded (mutually exclusive with passing specific models).",
show_default=True,
),
] = False,
quiet: Annotated[ quiet: Annotated[
bool, bool,
typer.Option( typer.Option(
@ -75,6 +84,10 @@ def download(
), ),
] = False, ] = False,
): ):
if models and all:
raise typer.BadParameter(
"Cannot simultaneously set 'all' parameter and specify models to download."
)
if not quiet: if not quiet:
FORMAT = "%(message)s" FORMAT = "%(message)s"
logging.basicConfig( logging.basicConfig(
@ -83,7 +96,7 @@ def download(
datefmt="[%X]", datefmt="[%X]",
handlers=[RichHandler(show_level=False, show_time=False, markup=True)], handlers=[RichHandler(show_level=False, show_time=False, markup=True)],
) )
to_download = models or _default_models to_download = models or ([m for m in _AvailableModels] if all else _default_models)
output_dir = download_models( output_dir = download_models(
output_dir=output_dir, output_dir=output_dir,
force=force, force=force,