mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-11 06:08:09 +00:00
feat: Support cuda:n GPU device allocation (#694)
* Adding multi-gpu support, and cuda device allocation Signed-off-by: ahn <ahn@zurich.ibm.com> * Fixes pydantic exception with cuda:n Signed-off-by: ahn <ahn@zurich.ibm.com> * Pydantic field validator and comment restored. Signed-off-by: ahn <ahn@zurich.ibm.com> * chore: Accept AcceleratorDevice enum type Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Resetted some options to default, removed EasyOCR model wrap. Signed-off-by: ahn <ahn@zurich.ibm.com> * Fixed rebased issues Signed-off-by: ahn <ahn@zurich.ibm.com> * Revert accelerator test options Signed-off-by: ahn <ahn@zurich.ibm.com> --------- Signed-off-by: ahn <ahn@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Co-authored-by: ahn <ahn@sonny.zuvela.ibm.com> Co-authored-by: ahn <ahn@zurich.ibm.com> Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -7,36 +7,62 @@ from docling.datamodel.pipeline_options import AcceleratorDevice
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decide_device(accelerator_device: AcceleratorDevice) -> str:
|
||||
def decide_device(accelerator_device: str) -> str:
|
||||
r"""
|
||||
Resolve the device based on the acceleration options and the available devices in the system
|
||||
Resolve the device based on the acceleration options and the available devices in the system.
|
||||
|
||||
Rules:
|
||||
1. AUTO: Check for the best available device on the system.
|
||||
2. User-defined: Check if the device actually exists, otherwise fall-back to CPU
|
||||
"""
|
||||
cuda_index = 0
|
||||
device = "cpu"
|
||||
|
||||
has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
|
||||
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||
|
||||
if accelerator_device == AcceleratorDevice.AUTO:
|
||||
if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto'
|
||||
if has_cuda:
|
||||
device = f"cuda:{cuda_index}"
|
||||
device = "cuda:0"
|
||||
elif has_mps:
|
||||
device = "mps"
|
||||
|
||||
elif accelerator_device.startswith("cuda"):
|
||||
if has_cuda:
|
||||
# if cuda device index specified extract device id
|
||||
parts = accelerator_device.split(":")
|
||||
if len(parts) == 2 and parts[1].isdigit():
|
||||
# select cuda device's id
|
||||
cuda_index = int(parts[1])
|
||||
if cuda_index < torch.cuda.device_count():
|
||||
device = f"cuda:{cuda_index}"
|
||||
else:
|
||||
_log.warning(
|
||||
"CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.",
|
||||
cuda_index,
|
||||
)
|
||||
elif len(parts) == 1: # just "cuda"
|
||||
device = "cuda:0"
|
||||
else:
|
||||
_log.warning(
|
||||
"Invalid CUDA device format '%s'. Fall back to 'CPU'",
|
||||
accelerator_device,
|
||||
)
|
||||
else:
|
||||
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
|
||||
|
||||
elif accelerator_device == AcceleratorDevice.MPS.value:
|
||||
if has_mps:
|
||||
device = "mps"
|
||||
else:
|
||||
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
|
||||
|
||||
elif accelerator_device == AcceleratorDevice.CPU.value:
|
||||
device = "cpu"
|
||||
|
||||
else:
|
||||
if accelerator_device == AcceleratorDevice.CUDA:
|
||||
if has_cuda:
|
||||
device = f"cuda:{cuda_index}"
|
||||
else:
|
||||
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
|
||||
elif accelerator_device == AcceleratorDevice.MPS:
|
||||
if has_mps:
|
||||
device = "mps"
|
||||
else:
|
||||
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
|
||||
_log.warning(
|
||||
"Unknown device option '%s'. Fall back to 'CPU'", accelerator_device
|
||||
)
|
||||
|
||||
_log.info("Accelerator device: '%s'", device)
|
||||
return device
|
||||
|
||||
Reference in New Issue
Block a user