mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-01 15:02:21 +00:00
Fixes pydantic exception with cuda:n
Signed-off-by: ahn <ahn@zurich.ibm.com>
This commit is contained in:
parent
fd51a7fa1f
commit
b9668877be
@ -39,18 +39,15 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_threads: int = 4
|
num_threads: int = 4
|
||||||
device: AcceleratorDevice = AcceleratorDevice.AUTO
|
device: str = "auto"
|
||||||
|
|
||||||
@validator("device")
|
@validator("device")
|
||||||
def validate_device(cls, value):
|
def validate_device(cls, value):
|
||||||
# Allow both Enum and str inputs
|
# "auto", "cpu", "cuda", "mps", or "cuda:N"
|
||||||
if isinstance(value, AcceleratorDevice):
|
|
||||||
return value
|
|
||||||
# Validate as a string
|
|
||||||
if value in {d.value for d in AcceleratorDevice} or re.match(
|
if value in {d.value for d in AcceleratorDevice} or re.match(
|
||||||
r"^cuda(:\d+)?$", value
|
r"^cuda(:\d+)?$", value
|
||||||
):
|
):
|
||||||
return AcceleratorDevice(value)
|
return value
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
|
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
|
||||||
)
|
)
|
||||||
@ -58,19 +55,8 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_alternative_envvars(cls, data: Any) -> Any:
|
def check_alternative_envvars(cls, data: Any) -> Any:
|
||||||
r"""
|
|
||||||
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
|
|
||||||
The alternative envvar is used only if it is valid and the regular envvar is not set.
|
|
||||||
|
|
||||||
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
|
|
||||||
the same functionality. In case the alias envvar is set and the user tries to override the
|
|
||||||
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
|
|
||||||
as an extra input instead of simply overwriting the evvar value for that parameter.
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
input_num_threads = data.get("num_threads")
|
input_num_threads = data.get("num_threads")
|
||||||
|
|
||||||
# Check if to set the num_threads from the alternative envvar
|
|
||||||
if input_num_threads is None:
|
if input_num_threads is None:
|
||||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||||
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
||||||
|
Loading…
Reference in New Issue
Block a user