From 5c69081453a7e8a2f2139fb858da6aa1c26ae6fd Mon Sep 17 00:00:00 2001 From: Nikos Livathinos Date: Tue, 10 Dec 2024 15:23:56 +0000 Subject: [PATCH] fix: Ocr AccleratorDevice Signed-off-by: Nikos Livathinos Signed-off-by: Christoph Auer --- docling/models/easyocr_model.py | 4 ++-- docling/models/rapid_ocr_model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index a10f5ba2..4387cd82 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -47,8 +47,8 @@ class EasyOcrModel(BaseOcrModel): # Enable easyocr GPU if running on CUDA, MPS use_gpu = any( filter( - lambda x: str(x) in device, - ["cuda", "mps"], + lambda x: str(x).lower() in device, + [AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value], ) ) diff --git a/docling/models/rapid_ocr_model.py b/docling/models/rapid_ocr_model.py index ec471253..fe8e2446 100644 --- a/docling/models/rapid_ocr_model.py +++ b/docling/models/rapid_ocr_model.py @@ -42,7 +42,7 @@ class RapidOcrModel(BaseOcrModel): # Decide the accelerator devices device = decide_device(accelerator_options.device) - use_cuda = "cuda" in device + use_cuda = str(AcceleratorDevice.CUDA.value).lower() in device use_dml = accelerator_options.device == AcceleratorDevice.AUTO intra_op_num_threads = accelerator_options.num_threads