Fixes pydantic exception with cuda:n

Signed-off-by: ahn <ahn@zurich.ibm.com>
This commit is contained in:
ahn 2025-01-07 14:41:38 +01:00
parent fd51a7fa1f
commit b9668877be

View File

@ -39,18 +39,15 @@ class AcceleratorOptions(BaseSettings):
)
num_threads: int = 4
device: AcceleratorDevice = AcceleratorDevice.AUTO
device: str = "auto"
@validator("device")
def validate_device(cls, value):
# Allow both Enum and str inputs
if isinstance(value, AcceleratorDevice):
return value
# Validate as a string
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
r"^cuda(:\d+)?$", value
):
return AcceleratorDevice(value)
return value
raise ValueError(
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
)
@ -58,19 +55,8 @@ class AcceleratorOptions(BaseSettings):
@model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
the same functionality. In case the alias envvar is set and the user tries to override the
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the evvar value for that parameter.
"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
omp_num_threads = os.getenv("OMP_NUM_THREADS")