Nail the accelerator defaults for MPS

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-13 17:19:10 +01:00
parent 6832c5aeba
commit 24f0346d84
5 changed files with 32 additions and 18 deletions

View File

@ -1,15 +1,17 @@
import logging
import os
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from typing_extensions import deprecated
_log = logging.getLogger(__name__)
@ -134,14 +136,8 @@ class EasyOcrOptions(OcrOptions):
kind: Literal["easyocr"] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"]
use_gpu: Annotated[
bool,
Field(
deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
),
] = True
use_gpu: Optional[bool] = None
model_storage_directory: Optional[str] = None
download_enabled: bool = True

View File

@ -1,4 +1,5 @@
import logging
import warnings
from typing import Iterable
import numpy
@ -41,16 +42,25 @@ class EasyOcrModel(BaseOcrModel):
"Alternatively, Docling has support for other OCR engines. See the documentation."
)
use_gpu = False
if self.options.use_gpu:
if self.options.use_gpu is None:
device = decide_device(accelerator_options.device)
# Enable easyocr GPU if running on CUDA, MPS
use_gpu = any(
filter(
lambda x: str(x).lower() in device,
[AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value],
)
[
device.startswith(x)
for x in [
AcceleratorDevice.CUDA.value,
AcceleratorDevice.MPS.value,
]
]
)
else:
warnings.warn(
"Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
)
use_gpu = self.options.use_gpu
self.reader = easyocr.Reader(
lang_list=self.options.lang,

View File

@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
Page,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
@ -50,8 +50,9 @@ class LayoutModel(BasePageModel):
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
device = decide_device(accelerator_options.device)
self.layout_predictor = LayoutPredictor(
artifact_path=artifacts_path,
artifact_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
base_threshold=0.6,

View File

@ -10,6 +10,7 @@ from PIL import ImageDraw
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
TableFormerMode,
TableStructureOptions,
@ -44,6 +45,10 @@ class TableStructureModel(BasePageModel):
device = decide_device(accelerator_options.device)
# Disable MPS here, until we know why it makes things slower.
if device == AcceleratorDevice.MPS.value:
device = AcceleratorDevice.CPU.value
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
self.tm_config["model"]["save_dir"] = artifacts_path
self.tm_model_type = self.tm_config["model"]["type"]

View File

@ -21,9 +21,11 @@ def decide_device(accelerator_device: AcceleratorDevice) -> str:
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
if accelerator_device == AcceleratorDevice.AUTO:
# TODO: Enable MPS later
if has_cuda:
device = f"cuda:{cuda_index}"
elif has_mps:
device = "mps"
else:
if accelerator_device == AcceleratorDevice.CUDA:
if has_cuda: