mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-02 15:32:30 +00:00
Nail the accelerator defaults for MPS
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
6832c5aeba
commit
24f0346d84
@ -1,15 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
PydanticBaseSettingsSource,
|
PydanticBaseSettingsSource,
|
||||||
SettingsConfigDict,
|
SettingsConfigDict,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -134,14 +136,8 @@ class EasyOcrOptions(OcrOptions):
|
|||||||
|
|
||||||
kind: Literal["easyocr"] = "easyocr"
|
kind: Literal["easyocr"] = "easyocr"
|
||||||
lang: List[str] = ["fr", "de", "es", "en"]
|
lang: List[str] = ["fr", "de", "es", "en"]
|
||||||
use_gpu: Annotated[
|
|
||||||
bool,
|
use_gpu: Optional[bool] = None
|
||||||
Field(
|
|
||||||
deprecated="Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
|
|
||||||
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
|
|
||||||
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
|
|
||||||
),
|
|
||||||
] = True
|
|
||||||
|
|
||||||
model_storage_directory: Optional[str] = None
|
model_storage_directory: Optional[str] = None
|
||||||
download_enabled: bool = True
|
download_enabled: bool = True
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@ -41,16 +42,25 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
"Alternatively, Docling has support for other OCR engines. See the documentation."
|
"Alternatively, Docling has support for other OCR engines. See the documentation."
|
||||||
)
|
)
|
||||||
|
|
||||||
use_gpu = False
|
if self.options.use_gpu is None:
|
||||||
if self.options.use_gpu:
|
|
||||||
device = decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
# Enable easyocr GPU if running on CUDA, MPS
|
# Enable easyocr GPU if running on CUDA, MPS
|
||||||
use_gpu = any(
|
use_gpu = any(
|
||||||
filter(
|
[
|
||||||
lambda x: str(x).lower() in device,
|
device.startswith(x)
|
||||||
[AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value],
|
for x in [
|
||||||
|
AcceleratorDevice.CUDA.value,
|
||||||
|
AcceleratorDevice.MPS.value,
|
||||||
|
]
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"Deprecated field. Better to set the `accelerator_options.device` in `pipeline_options`. "
|
||||||
|
"When `use_gpu and accelerator_options.device == AcceleratorDevice.CUDA` the GPU is used "
|
||||||
|
"to run EasyOCR. Otherwise, EasyOCR runs in CPU."
|
||||||
)
|
)
|
||||||
|
use_gpu = self.options.use_gpu
|
||||||
|
|
||||||
self.reader = easyocr.Reader(
|
self.reader = easyocr.Reader(
|
||||||
lang_list=self.options.lang,
|
lang_list=self.options.lang,
|
||||||
|
@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
|
|||||||
Page,
|
Page,
|
||||||
)
|
)
|
||||||
from docling.datamodel.document import ConversionResult
|
from docling.datamodel.document import ConversionResult
|
||||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel
|
||||||
from docling.utils.accelerator_utils import decide_device
|
from docling.utils.accelerator_utils import decide_device
|
||||||
@ -50,8 +50,9 @@ class LayoutModel(BasePageModel):
|
|||||||
|
|
||||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||||
device = decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
|
|
||||||
self.layout_predictor = LayoutPredictor(
|
self.layout_predictor = LayoutPredictor(
|
||||||
artifact_path=artifacts_path,
|
artifact_path=str(artifacts_path),
|
||||||
device=device,
|
device=device,
|
||||||
num_threads=accelerator_options.num_threads,
|
num_threads=accelerator_options.num_threads,
|
||||||
base_threshold=0.6,
|
base_threshold=0.6,
|
||||||
|
@ -10,6 +10,7 @@ from PIL import ImageDraw
|
|||||||
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
||||||
from docling.datamodel.document import ConversionResult
|
from docling.datamodel.document import ConversionResult
|
||||||
from docling.datamodel.pipeline_options import (
|
from docling.datamodel.pipeline_options import (
|
||||||
|
AcceleratorDevice,
|
||||||
AcceleratorOptions,
|
AcceleratorOptions,
|
||||||
TableFormerMode,
|
TableFormerMode,
|
||||||
TableStructureOptions,
|
TableStructureOptions,
|
||||||
@ -44,6 +45,10 @@ class TableStructureModel(BasePageModel):
|
|||||||
|
|
||||||
device = decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
|
|
||||||
|
# Disable MPS here, until we know why it makes things slower.
|
||||||
|
if device == AcceleratorDevice.MPS.value:
|
||||||
|
device = AcceleratorDevice.CPU.value
|
||||||
|
|
||||||
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
||||||
self.tm_config["model"]["save_dir"] = artifacts_path
|
self.tm_config["model"]["save_dir"] = artifacts_path
|
||||||
self.tm_model_type = self.tm_config["model"]["type"]
|
self.tm_model_type = self.tm_config["model"]["type"]
|
||||||
|
@ -21,9 +21,11 @@ def decide_device(accelerator_device: AcceleratorDevice) -> str:
|
|||||||
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||||
|
|
||||||
if accelerator_device == AcceleratorDevice.AUTO:
|
if accelerator_device == AcceleratorDevice.AUTO:
|
||||||
# TODO: Enable MPS later
|
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
device = f"cuda:{cuda_index}"
|
device = f"cuda:{cuda_index}"
|
||||||
|
elif has_mps:
|
||||||
|
device = "mps"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if accelerator_device == AcceleratorDevice.CUDA:
|
if accelerator_device == AcceleratorDevice.CUDA:
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
|
Loading…
Reference in New Issue
Block a user