Fixes pydantic exception with cuda:n

Signed-off-by: ahn <ahn@zurich.ibm.com>
This commit is contained in:
ahn 2025-01-07 14:41:38 +01:00
parent fd51a7fa1f
commit b9668877be

View File

@ -39,18 +39,15 @@ class AcceleratorOptions(BaseSettings):
) )
num_threads: int = 4 num_threads: int = 4
device: AcceleratorDevice = AcceleratorDevice.AUTO device: str = "auto"
@validator("device") @validator("device")
def validate_device(cls, value): def validate_device(cls, value):
# Allow both Enum and str inputs # "auto", "cpu", "cuda", "mps", or "cuda:N"
if isinstance(value, AcceleratorDevice):
return value
# Validate as a string
if value in {d.value for d in AcceleratorDevice} or re.match( if value in {d.value for d in AcceleratorDevice} or re.match(
r"^cuda(:\d+)?$", value r"^cuda(:\d+)?$", value
): ):
return AcceleratorDevice(value) return value
raise ValueError( raise ValueError(
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." "Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
) )
@ -58,19 +55,8 @@ class AcceleratorOptions(BaseSettings):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_alternative_envvars(cls, data: Any) -> Any: def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
the same functionality. In case the alias envvar is set and the user tries to override the
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the evvar value for that parameter.
"""
if isinstance(data, dict): if isinstance(data, dict):
input_num_threads = data.get("num_threads") input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar
if input_num_threads is None: if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS") docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
omp_num_threads = os.getenv("OMP_NUM_THREADS") omp_num_threads = os.getenv("OMP_NUM_THREADS")