fix: Ocr AccleratorDevice

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Nikos Livathinos 2024-12-10 15:23:56 +00:00
parent 6bc1bd2ec4
commit 5c69081453
2 changed files with 3 additions and 3 deletions

View File

@ -47,8 +47,8 @@ class EasyOcrModel(BaseOcrModel):
# Enable easyocr GPU if running on CUDA, MPS
use_gpu = any(
filter(
lambda x: str(x) in device,
["cuda", "mps"],
lambda x: str(x).lower() in device,
[AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value],
)
)

View File

@ -42,7 +42,7 @@ class RapidOcrModel(BaseOcrModel):
# Decide the accelerator devices
device = decide_device(accelerator_options.device)
use_cuda = "cuda" in device
use_cuda = str(AcceleratorDevice.CUDA.value).lower() in device
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
intra_op_num_threads = accelerator_options.num_threads