feat(Accelerator): Introduce options to control the num_threads and device from API, envvars, CLI.

- Introduce the AcceleratorOptions, AcceleratorDevice and use them to set the device where the models run.
- Introduce the accelerator_utils module with a function to decide the device and resolve the AUTO setting.
- Refactor how the docling-ibm-models are called to match the new init signature of the models.
- Translate the accelerator options to the specific inputs for third-party models.
- Extend the docling CLI with parameters to set the num_threads and device.
- Add new unit tests.
- Add a new example showing how to use the accelerator options.

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Nikos Livathinos 2024-12-02 18:27:44 +01:00
parent 78fad801fe
commit 3bb7df66ca
11 changed files with 325 additions and 59 deletions

View File

@ -24,6 +24,8 @@ from docling.datamodel.base_models import (
) )
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ( from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions, EasyOcrOptions,
OcrMacOptions, OcrMacOptions,
OcrOptions, OcrOptions,
@ -241,6 +243,10 @@ def convert(
help="Show version information.", help="Show version information.",
), ),
] = None, ] = None,
num_threads: Annotated[int, typer.Option(..., help="Number of threads")] = 4,
device: Annotated[
AcceleratorDevice, typer.Option(..., help="Accelerator device")
] = AcceleratorDevice.AUTO,
): ):
if verbose == 0: if verbose == 0:
logging.basicConfig(level=logging.WARNING) logging.basicConfig(level=logging.WARNING)
@ -299,7 +305,9 @@ def convert(
if ocr_lang_list is not None: if ocr_lang_list is not None:
ocr_options.lang = ocr_lang_list ocr_options.lang = ocr_lang_list
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
pipeline_options = PdfPipelineOptions( pipeline_options = PdfPipelineOptions(
accelerator_options=accelerator_options,
do_ocr=ocr, do_ocr=ocr,
ocr_options=ocr_options, ocr_options=ocr_options,
do_table_structure=True, do_table_structure=True,

View File

@ -1,8 +1,62 @@
import logging
import os
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
_log = logging.getLogger(__name__)
class AcceleratorDevice(str, Enum):
    """Devices that can be requested for running model inference."""

    # No explicit choice: the concrete device is resolved at runtime
    AUTO = "auto"
    # Explicit device requests
    CPU = "cpu"
    CUDA = "cuda"
    MPS = "mps"
class AcceleratorOptions(BaseSettings):
    """
    Settings that control how many threads and which device the models use.

    Values not passed explicitly are read from envvars carrying the
    "DOCLING_" prefix (e.g. DOCLING_NUM_THREADS, DOCLING_DEVICE) before the
    defaults below apply.
    """

    model_config = SettingsConfigDict(
        env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
    )

    # A pydantic alias is deliberately not used for OMP_NUM_THREADS, see __init__:
    # num_threads: int = Field(default=4, validation_alias="omp_num_threads")
    num_threads: int = 4
    device: AcceleratorDevice = AcceleratorDevice.AUTO

    def __init__(self, **data):
        r"""
        Set num_threads from the "alternative" envvar OMP_NUM_THREADS.

        The alternative envvar is used only if it holds a valid integer, the
        regular envvar DOCLING_NUM_THREADS is not set and no num_threads has
        been passed to the initializer. A misformatted OMP_NUM_THREADS is
        logged and ignored.

        Notice: The standard pydantic settings mechanism with parameter "aliases" does not
        provide the same functionality. In case the alias envvar is set and the user tries
        to override the parameter in settings initialization, Pydantic treats the parameter
        provided in __init__() as an extra input instead of simply overwriting the envvar
        value for that parameter.
        """
        # An explicitly provided num_threads always wins over any envvar
        if data.get("num_threads") is None:
            fallback_threads = os.getenv("OMP_NUM_THREADS")
            regular_is_unset = os.getenv("DOCLING_NUM_THREADS") is None

            if regular_is_unset and fallback_threads is not None:
                try:
                    data["num_threads"] = int(fallback_threads)
                except ValueError:
                    _log.error(
                        "Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
                        fallback_threads,
                    )

        super().__init__(**data)
class TableFormerMode(str, Enum): class TableFormerMode(str, Enum):
@ -78,7 +132,16 @@ class EasyOcrOptions(OcrOptions):
kind: Literal["easyocr"] = "easyocr" kind: Literal["easyocr"] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"] lang: List[str] = ["fr", "de", "es", "en"]
use_gpu: bool = True # same default as easyocr.Reader use_gpu: bool = Field(
default=True,
deprecated=(
"Field `use_gpu` is deprecated. "
"When `use_gpu == True and accelerator_options.device == AcceleratorDevice.CUDA` "
"then the GPU is used to run EasyOCR. "
"When `use_gpu == False`, EasyOCR runs in CPU"
),
)
model_storage_directory: Optional[str] = None model_storage_directory: Optional[str] = None
download_enabled: bool = True # same default as easyocr.Reader download_enabled: bool = True # same default as easyocr.Reader
@ -132,6 +195,7 @@ class PipelineOptions(BaseModel):
create_legacy_output: bool = ( create_legacy_output: bool = (
True # This defautl will be set to False on a future version of docling True # This defautl will be set to False on a future version of docling
) )
accelerator_options: AcceleratorOptions = AcceleratorOptions()
class PdfPipelineOptions(PipelineOptions): class PdfPipelineOptions(PipelineOptions):

View File

@ -7,16 +7,26 @@ from docling_core.types.doc import BoundingBox, CoordOrigin
from docling.datamodel.base_models import Cell, OcrCell, Page from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import EasyOcrOptions from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
)
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.utils import accelerator_utils as au
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class EasyOcrModel(BaseOcrModel): class EasyOcrModel(BaseOcrModel):
def __init__(self, enabled: bool, options: EasyOcrOptions): def __init__(
self,
enabled: bool,
options: EasyOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(enabled=enabled, options=options) super().__init__(enabled=enabled, options=options)
self.options: EasyOcrOptions self.options: EasyOcrOptions
@ -31,11 +41,20 @@ class EasyOcrModel(BaseOcrModel):
"Alternatively, Docling has support for other OCR engines. See the documentation." "Alternatively, Docling has support for other OCR engines. See the documentation."
) )
# EasyOCR only takes a boolean GPU switch. The GPU is enabled when the
# (deprecated) `use_gpu` option allows it AND the resolved accelerator
# device is a GPU backend.
use_gpu = False
if self.options.use_gpu:
    # decide_device() returns a torch-style device string such as "cuda:0",
    # "mps" or "cpu". A CUDA device carries an index suffix, so it never
    # compares equal to AcceleratorDevice.CUDA ("cuda"); match on the
    # backend prefix instead of testing for equality.
    device = au.decide_device(accelerator_options.device)
    # Enable easyocr GPU if running on CUDA, MPS
    use_gpu = device.startswith(
        (AcceleratorDevice.CUDA.value, AcceleratorDevice.MPS.value)
    )
self.reader = easyocr.Reader( self.reader = easyocr.Reader(
lang_list=self.options.lang, lang_list=self.options.lang,
gpu=self.options.use_gpu, gpu=use_gpu,
model_storage_directory=self.options.model_storage_directory, model_storage_directory=self.options.model_storage_directory,
download_enabled=self.options.download_enabled, download_enabled=self.options.download_enabled,
verbose=False,
) )
def __call__( def __call__(

View File

@ -17,8 +17,10 @@ from docling.datamodel.base_models import (
Page, Page,
) )
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel
from docling.utils import accelerator_utils as au
from docling.utils import layout_utils as lu from docling.utils import layout_utils as lu
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
@ -46,8 +48,11 @@ class LayoutModel(BasePageModel):
FIGURE_LABEL = DocItemLabel.PICTURE FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA FORMULA_LABEL = DocItemLabel.FORMULA
def __init__(self, artifacts_path: Path): def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary device = au.decide_device(accelerator_options.device)
self.layout_predictor = LayoutPredictor(
artifacts_path, device, accelerator_options.num_threads
)
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height): def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
MIN_INTERSECTION = 0.2 MIN_INTERSECTION = 0.2

View File

@ -6,16 +6,26 @@ from docling_core.types.doc import BoundingBox, CoordOrigin
from docling.datamodel.base_models import OcrCell, Page from docling.datamodel.base_models import OcrCell, Page
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import RapidOcrOptions from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
RapidOcrOptions,
)
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.utils import accelerator_utils as au
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class RapidOcrModel(BaseOcrModel): class RapidOcrModel(BaseOcrModel):
def __init__(self, enabled: bool, options: RapidOcrOptions): def __init__(
self,
enabled: bool,
options: RapidOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(enabled=enabled, options=options) super().__init__(enabled=enabled, options=options)
self.options: RapidOcrOptions self.options: RapidOcrOptions
@ -30,52 +40,21 @@ class RapidOcrModel(BaseOcrModel):
"Alternatively, Docling has support for other OCR engines. See the documentation." "Alternatively, Docling has support for other OCR engines. See the documentation."
) )
# This configuration option will be revamped while introducing device settings for all models. # Decide the accelerator devices
# For the moment we will default to auto and let onnx-runtime pick the best. device = au.decide_device(accelerator_options.device)
cls_use_cuda = True use_cuda = device == AcceleratorDevice.CUDA
rec_use_cuda = True use_dml = accelerator_options.device == AcceleratorDevice.AUTO
det_use_cuda = True intra_op_num_threads = accelerator_options.num_threads
det_use_dml = True
cls_use_dml = True
rec_use_dml = True
# # Same as Defaults in RapidOCR
# cls_use_cuda = False
# rec_use_cuda = False
# det_use_cuda = False
# det_use_dml = False
# cls_use_dml = False
# rec_use_dml = False
# # If we set everything to true onnx-runtime would automatically choose the fastest accelerator
# if self.options.device == self.options.Device.AUTO:
# cls_use_cuda = True
# rec_use_cuda = True
# det_use_cuda = True
# det_use_dml = True
# cls_use_dml = True
# rec_use_dml = True
# # If we set use_cuda to true onnx would use the cuda device available in runtime if no cuda device is available it would run on CPU.
# elif self.options.device == self.options.Device.CUDA:
# cls_use_cuda = True
# rec_use_cuda = True
# det_use_cuda = True
# # If we set use_dml to true onnx would use the dml device available in runtime if no dml device is available it would work on CPU.
# elif self.options.device == self.options.Device.DIRECTML:
# det_use_dml = True
# cls_use_dml = True
# rec_use_dml = True
self.reader = RapidOCR( self.reader = RapidOCR(
text_score=self.options.text_score, text_score=self.options.text_score,
cls_use_cuda=cls_use_cuda, cls_use_cuda=use_cuda,
rec_use_cuda=rec_use_cuda, rec_use_cuda=use_cuda,
det_use_cuda=det_use_cuda, det_use_cuda=use_cuda,
det_use_dml=det_use_dml, det_use_dml=use_dml,
cls_use_dml=cls_use_dml, cls_use_dml=use_dml,
rec_use_dml=rec_use_dml, rec_use_dml=use_dml,
intra_op_num_threads=intra_op_num_threads,
print_verbose=self.options.print_verbose, print_verbose=self.options.print_verbose,
det_model_path=self.options.det_model_path, det_model_path=self.options.det_model_path,
cls_model_path=self.options.cls_model_path, cls_model_path=self.options.cls_model_path,

View File

@ -9,15 +9,24 @@ from PIL import ImageDraw
from docling.datamodel.base_models import Page, Table, TableStructurePrediction from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import TableFormerMode, TableStructureOptions from docling.datamodel.pipeline_options import (
AcceleratorOptions,
TableFormerMode,
TableStructureOptions,
)
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel
from docling.utils import accelerator_utils as au
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
class TableStructureModel(BasePageModel): class TableStructureModel(BasePageModel):
def __init__( def __init__(
self, enabled: bool, artifacts_path: Path, options: TableStructureOptions self,
enabled: bool,
artifacts_path: Path,
options: TableStructureOptions,
accelerator_options: AcceleratorOptions,
): ):
self.options = options self.options = options
self.do_cell_matching = self.options.do_cell_matching self.do_cell_matching = self.options.do_cell_matching
@ -31,11 +40,15 @@ class TableStructureModel(BasePageModel):
# Third Party # Third Party
import docling_ibm_models.tableformer.common as c import docling_ibm_models.tableformer.common as c
device = au.decide_device(accelerator_options.device)
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json") self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
self.tm_config["model"]["save_dir"] = artifacts_path self.tm_config["model"]["save_dir"] = artifacts_path
self.tm_model_type = self.tm_config["model"]["type"] self.tm_model_type = self.tm_config["model"]["type"]
self.tf_predictor = TFPredictor(self.tm_config) self.tf_predictor = TFPredictor(
self.tm_config, device, accelerator_options.num_threads
)
self.scale = 2.0 # Scale up table input images to 144 dpi self.scale = 2.0 # Scale up table input images to 144 dpi
def draw_table_and_cells( def draw_table_and_cells(

View File

@ -38,7 +38,9 @@ _log = logging.getLogger(__name__)
class StandardPdfPipeline(PaginatedPipeline): class StandardPdfPipeline(PaginatedPipeline):
_layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt" # TODO: Revise after having the models in HF
# _layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
_layout_model_path = "model_artifacts/layout/"
_table_model_path = "model_artifacts/tableformer" _table_model_path = "model_artifacts/tableformer"
def __init__(self, pipeline_options: PdfPipelineOptions): def __init__(self, pipeline_options: PdfPipelineOptions):
@ -75,7 +77,8 @@ class StandardPdfPipeline(PaginatedPipeline):
# Layout model # Layout model
LayoutModel( LayoutModel(
artifacts_path=self.artifacts_path artifacts_path=self.artifacts_path
/ StandardPdfPipeline._layout_model_path / StandardPdfPipeline._layout_model_path,
accelerator_options=pipeline_options.accelerator_options,
), ),
# Table structure model # Table structure model
TableStructureModel( TableStructureModel(
@ -83,6 +86,7 @@ class StandardPdfPipeline(PaginatedPipeline):
artifacts_path=self.artifacts_path artifacts_path=self.artifacts_path
/ StandardPdfPipeline._table_model_path, / StandardPdfPipeline._table_model_path,
options=pipeline_options.table_structure_options, options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options,
), ),
# Page assemble # Page assemble
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)), PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
@ -98,11 +102,18 @@ class StandardPdfPipeline(PaginatedPipeline):
) -> Path: ) -> Path:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
# TODO: Revise after having the models in HF
# download_path = snapshot_download(
# repo_id="ds4sd/docling-models",
# force_download=force,
# local_dir=local_dir,
# revision="v2.0.1",
# )
download_path = snapshot_download( download_path = snapshot_download(
repo_id="ds4sd/docling-models", repo_id="ds4sd/docling-models",
force_download=force, force_download=force,
local_dir=local_dir, local_dir=local_dir,
revision="v2.0.1", revision="refs/pr/2",
) )
return Path(download_path) return Path(download_path)
@ -112,6 +123,7 @@ class StandardPdfPipeline(PaginatedPipeline):
return EasyOcrModel( return EasyOcrModel(
enabled=self.pipeline_options.do_ocr, enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options, options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options,
) )
elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions): elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions):
return TesseractOcrCliModel( return TesseractOcrCliModel(
@ -127,6 +139,7 @@ class StandardPdfPipeline(PaginatedPipeline):
return RapidOcrModel( return RapidOcrModel(
enabled=self.pipeline_options.do_ocr, enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options, options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options,
) )
elif isinstance(self.pipeline_options.ocr_options, OcrMacOptions): elif isinstance(self.pipeline_options.ocr_options, OcrMacOptions):
if "darwin" != sys.platform: if "darwin" != sys.platform:

View File

@ -0,0 +1,40 @@
import logging
import torch
from docling.datamodel.pipeline_options import AcceleratorDevice
_log = logging.getLogger(__name__)
def decide_device(accelerator_device: AcceleratorDevice) -> str:
    r"""
    Translate the requested accelerator into a device string usable on this system.

    Rules:
    1. AUTO: Pick "cuda:0" when CUDA is built and available, otherwise "cpu".
       MPS is not selected automatically yet.
    2. CUDA / MPS: Use the requested device when torch reports it as built and
       available, otherwise log a warning and fall back to "cpu".
    3. Anything else resolves to "cpu".

    Returns the resolved device string ("cuda:0", "mps" or "cpu").
    """
    cuda_index = 0

    # Probe both backends up front, regardless of the requested device
    cuda_usable = torch.backends.cuda.is_built() and torch.cuda.is_available()
    mps_usable = torch.backends.mps.is_built() and torch.backends.mps.is_available()

    # CPU is the answer unless one of the branches below finds something better
    resolved = "cpu"

    if accelerator_device == AcceleratorDevice.AUTO:
        # TODO: Enable MPS later
        if cuda_usable:
            resolved = f"cuda:{cuda_index}"
    elif accelerator_device == AcceleratorDevice.CUDA:
        if cuda_usable:
            resolved = f"cuda:{cuda_index}"
        else:
            _log.warning("CUDA is not available in the system. Fall back to 'CPU'")
    elif accelerator_device == AcceleratorDevice.MPS:
        if mps_usable:
            resolved = "mps"
        else:
            _log.warning("MPS is not available in the system. Fall back to 'CPU'")

    _log.info("Accelerator device: '%s'", resolved)
    return resolved

View File

@ -0,0 +1,63 @@
from pathlib import Path
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
PdfPipelineOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, PdfFormatOption
def main():
    """Convert a sample PDF with explicit accelerator settings and print the timings."""
    source_pdf = Path("./tests/data/2206.01062.pdf")

    # Explicitly set the accelerator. To try another device, replace the
    # `device` argument with AcceleratorDevice.AUTO, AcceleratorDevice.MPS
    # or AcceleratorDevice.CUDA.
    accelerator = AcceleratorOptions(num_threads=8, device=AcceleratorDevice.CPU)

    pdf_options = PdfPipelineOptions()
    pdf_options.accelerator_options = accelerator
    pdf_options.do_ocr = True
    pdf_options.do_table_structure = True
    pdf_options.table_structure_options.do_cell_matching = True

    pdf_format_option = PdfFormatOption(pipeline_options=pdf_options)
    converter = DocumentConverter(format_options={InputFormat.PDF: pdf_format_option})

    # Enable the profiling to measure the time spent
    settings.debug.profile_pipeline_timings = True

    # Convert the document
    result = converter.convert(source_pdf)

    # List with total time per document
    doc_conversion_secs = result.timings["pipeline_total"].times

    print(result.document.export_to_markdown())
    print(f"Conversion secs: {doc_conversion_secs}")


if __name__ == "__main__":
    main()

View File

@ -76,6 +76,7 @@ nav:
- "Table export": examples/export_tables.py - "Table export": examples/export_tables.py
- "Multimodal export": examples/export_multimodal.py - "Multimodal export": examples/export_multimodal.py
- "Force full page OCR": examples/full_page_ocr.py - "Force full page OCR": examples/full_page_ocr.py
- "Accelerator options": examples/run_with_acclerators.py
- RAG / QA: - RAG / QA:
- "RAG with LlamaIndex 🦙": examples/rag_llamaindex.ipynb - "RAG with LlamaIndex 🦙": examples/rag_llamaindex.ipynb
- "RAG with LangChain 🦜🔗": examples/rag_langchain.ipynb - "RAG with LangChain 🦜🔗": examples/rag_langchain.ipynb

View File

@ -1,3 +1,4 @@
import os
from pathlib import Path from pathlib import Path
import pytest import pytest
@ -5,7 +6,12 @@ import pytest
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.datamodel.base_models import ConversionStatus, InputFormat from docling.datamodel.base_models import ConversionStatus, InputFormat
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import PdfPipelineOptions, TableFormerMode from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
PdfPipelineOptions,
TableFormerMode,
)
from docling.document_converter import DocumentConverter, PdfFormatOption from docling.document_converter import DocumentConverter, PdfFormatOption
@ -35,6 +41,61 @@ def get_converters_with_table_options():
yield converter yield converter
def test_accelerator_options():
    """
    Check AcceleratorOptions defaults, explicit arguments and envvar handling.

    The test manipulates DOCLING_NUM_THREADS, DOCLING_DEVICE and
    OMP_NUM_THREADS. They are removed before the checks start, so that a
    preset environment cannot break the default-value assertions, and the
    original values are restored afterwards, so that nothing leaks into
    the tests that run later in the same process.
    """
    envvar_names = ("DOCLING_NUM_THREADS", "DOCLING_DEVICE", "OMP_NUM_THREADS")
    saved_env = {name: os.environ.pop(name, None) for name in envvar_names}

    try:
        # Check the default options
        ao = AcceleratorOptions()
        assert ao.num_threads == 4, "Wrong default num_threads"
        assert ao.device == AcceleratorDevice.AUTO, "Wrong default device"

        # Use API
        ao2 = AcceleratorOptions(num_threads=2, device=AcceleratorDevice.MPS)
        ao3 = AcceleratorOptions(num_threads=3, device=AcceleratorDevice.CUDA)
        assert ao2.num_threads == 2
        assert ao2.device == AcceleratorDevice.MPS
        assert ao3.num_threads == 3
        assert ao3.device == AcceleratorDevice.CUDA

        # Use envvars (regular + alternative) and default values
        os.environ["OMP_NUM_THREADS"] = "1"
        ao_omp = AcceleratorOptions()
        assert ao_omp.num_threads == 1
        assert ao_omp.device == AcceleratorDevice.AUTO

        os.environ["DOCLING_DEVICE"] = "cpu"
        ao_cpu = AcceleratorOptions()
        assert ao_cpu.device == AcceleratorDevice.CPU
        assert ao_cpu.num_threads == 1

        # Use envvars and override in init
        ao4 = AcceleratorOptions(num_threads=5, device=AcceleratorDevice.MPS)
        assert ao4.num_threads == 5
        assert ao4.device == AcceleratorDevice.MPS

        # Use regular and alternative envvar: the regular one takes precedence
        os.environ["DOCLING_NUM_THREADS"] = "2"
        ao5 = AcceleratorOptions()
        assert ao5.num_threads == 2
        assert ao5.device == AcceleratorDevice.CPU

        # Use wrong values. Pydantic validation errors derive from ValueError.
        os.environ["DOCLING_DEVICE"] = "wrong"
        with pytest.raises(ValueError):
            AcceleratorOptions()

        # Use misformatted alternative envvar: it is ignored, defaults apply
        del os.environ["DOCLING_NUM_THREADS"]
        del os.environ["DOCLING_DEVICE"]
        os.environ["OMP_NUM_THREADS"] = "wrong"
        ao6 = AcceleratorOptions()
        assert ao6.num_threads == 4
        assert ao6.device == AcceleratorDevice.AUTO
    finally:
        # Put the environment back exactly as it was found
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
def test_e2e_conversions(test_doc_path): def test_e2e_conversions(test_doc_path):
for converter in get_converters_with_table_options(): for converter in get_converters_with_table_options():
print(f"converting {test_doc_path}") print(f"converting {test_doc_path}")