mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-01 15:02:21 +00:00
Nail the accelerator defaults for MPS
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
6832c5aeba
commit
24f0346d84
@ -1,15 +1,17 @@
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
@ -134,14 +136,8 @@ class EasyOcrOptions(OcrOptions):
|
||||
|
||||
kind: Literal["easyocr"] = "easyocr"
|
||||
lang: List[str] = ["fr", "de", "es", "en"]
|
||||
use_gpu: Annotated[
|
||||
bool,
|
||||
Field(
|
||||
deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
|
||||
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
|
||||
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
|
||||
),
|
||||
] = True
|
||||
|
||||
use_gpu: Optional[bool] = None
|
||||
|
||||
model_storage_directory: Optional[str] = None
|
||||
download_enabled: bool = True
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
@ -41,16 +42,25 @@ class EasyOcrModel(BaseOcrModel):
|
||||
"Alternatively, Docling has support for other OCR engines. See the documentation."
|
||||
)
|
||||
|
||||
use_gpu = False
|
||||
if self.options.use_gpu:
|
||||
if self.options.use_gpu is None:
|
||||
device = decide_device(accelerator_options.device)
|
||||
# Enable easyocr GPU if running on CUDA, MPS
|
||||
use_gpu = any(
|
||||
filter(
|
||||
lambda x: str(x).lower() in device,
|
||||
[AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value],
|
||||
)
|
||||
[
|
||||
device.startswith(x)
|
||||
for x in [
|
||||
AcceleratorDevice.CUDA.value,
|
||||
AcceleratorDevice.MPS.value,
|
||||
]
|
||||
]
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
|
||||
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
|
||||
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
|
||||
)
|
||||
use_gpu = self.options.use_gpu
|
||||
|
||||
self.reader = easyocr.Reader(
|
||||
lang_list=self.options.lang,
|
||||
|
@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
|
||||
Page,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
@ -50,8 +50,9 @@ class LayoutModel(BasePageModel):
|
||||
|
||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
artifact_path=artifacts_path,
|
||||
artifact_path=str(artifacts_path),
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
base_threshold=0.6,
|
||||
|
@ -10,6 +10,7 @@ from PIL import ImageDraw
|
||||
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
TableFormerMode,
|
||||
TableStructureOptions,
|
||||
@ -44,6 +45,10 @@ class TableStructureModel(BasePageModel):
|
||||
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
# Disable MPS here, until we know why it makes things slower.
|
||||
if device == AcceleratorDevice.MPS.value:
|
||||
device = AcceleratorDevice.CPU.value
|
||||
|
||||
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
||||
self.tm_config["model"]["save_dir"] = artifacts_path
|
||||
self.tm_model_type = self.tm_config["model"]["type"]
|
||||
|
@ -21,9 +21,11 @@ def decide_device(accelerator_device: AcceleratorDevice) -> str:
|
||||
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||
|
||||
if accelerator_device == AcceleratorDevice.AUTO:
|
||||
# TODO: Enable MPS later
|
||||
if has_cuda:
|
||||
device = f"cuda:{cuda_index}"
|
||||
elif has_mps:
|
||||
device = "mps"
|
||||
|
||||
else:
|
||||
if accelerator_device == AcceleratorDevice.CUDA:
|
||||
if has_cuda:
|
||||
|
Loading…
Reference in New Issue
Block a user