mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
feat(Accelerator): Introduce options to control the num_threads and device from API, envvars, CLI.
- Introduce AcceleratorOptions and AcceleratorDevice and use them to set the device where the models run. - Introduce accelerator_utils with a function that decides the device and resolves the AUTO setting. - Refactor how the docling-ibm-models are called to match the new init signature of the models. - Translate the accelerator options into the specific inputs for third-party models. - Extend the docling CLI with parameters to set num_threads and device. - Add new unit tests. - Add a new example showing how to use the accelerator options. Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
78fad801fe
commit
3bb7df66ca
@ -24,6 +24,8 @@ from docling.datamodel.base_models import (
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
EasyOcrOptions,
|
||||
OcrMacOptions,
|
||||
OcrOptions,
|
||||
@ -241,6 +243,10 @@ def convert(
|
||||
help="Show version information.",
|
||||
),
|
||||
] = None,
|
||||
num_threads: Annotated[int, typer.Option(..., help="Number of threads")] = 4,
|
||||
device: Annotated[
|
||||
AcceleratorDevice, typer.Option(..., help="Accelerator device")
|
||||
] = AcceleratorDevice.AUTO,
|
||||
):
|
||||
if verbose == 0:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
@ -299,7 +305,9 @@ def convert(
|
||||
if ocr_lang_list is not None:
|
||||
ocr_options.lang = ocr_lang_list
|
||||
|
||||
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
|
||||
pipeline_options = PdfPipelineOptions(
|
||||
accelerator_options=accelerator_options,
|
||||
do_ocr=ocr,
|
||||
ocr_options=ocr_options,
|
||||
do_table_structure=True,
|
||||
|
@ -1,8 +1,62 @@
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AcceleratorDevice(str, Enum):
    """Hardware targets on which model inference can be requested.

    The members are ``str`` subclasses, so each one compares equal to its
    lowercase value (for example ``AcceleratorDevice.CPU == "cpu"``).
    """

    # No explicit device is named; the concrete device is picked elsewhere.
    AUTO = "auto"

    # Explicitly named devices.
    CPU = "cpu"
    CUDA = "cuda"
    MPS = "mps"
|
||||
|
||||
|
||||
class AcceleratorOptions(BaseSettings):
    """Settings that select the inference device and the number of threads.

    Fields are read from keyword arguments or from environment variables
    carrying the ``DOCLING_`` prefix; ``OMP_NUM_THREADS`` is honoured as a
    fallback for ``num_threads`` (see ``__init__``).
    """

    model_config = SettingsConfigDict(
        env_prefix="DOCLING_",
        env_nested_delimiter="_",
        populate_by_name=True,
    )

    # A pydantic validation alias on "omp_num_threads" is deliberately not
    # used for this field; __init__ explains why.
    num_threads: int = 4
    device: AcceleratorDevice = AcceleratorDevice.AUTO

    def __init__(self, **data):
        r"""
        Build the settings, optionally seeding num_threads from OMP_NUM_THREADS.

        OMP_NUM_THREADS is consulted only when all of the following hold:
        - the caller did not pass a (non-None) num_threads,
        - the regular envvar DOCLING_NUM_THREADS is not set,
        - OMP_NUM_THREADS is set and parses as an integer.
        A value that does not parse is logged and ignored.

        Notice: the standard pydantic settings "alias" mechanism does not give
        the same result. If the alias envvar is set and the caller also passes
        the parameter to __init__(), pydantic treats the passed parameter as an
        extra input instead of letting it override the envvar value.
        """
        caller_left_it_unset = data.get("num_threads") is None
        regular_envvar_unset = os.getenv("DOCLING_NUM_THREADS") is None

        if caller_left_it_unset and regular_envvar_unset:
            fallback = os.getenv("OMP_NUM_THREADS")
            if fallback is not None:
                try:
                    data["num_threads"] = int(fallback)
                except ValueError:
                    _log.error(
                        "Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
                        fallback,
                    )

        super().__init__(**data)
|
||||
|
||||
|
||||
class TableFormerMode(str, Enum):
|
||||
@ -78,7 +132,16 @@ class EasyOcrOptions(OcrOptions):
|
||||
|
||||
kind: Literal["easyocr"] = "easyocr"
|
||||
lang: List[str] = ["fr", "de", "es", "en"]
|
||||
use_gpu: bool = True # same default as easyocr.Reader
|
||||
use_gpu: bool = Field(
|
||||
default=True,
|
||||
deprecated=(
|
||||
"Field `use_gpu` is deprecated. "
|
||||
"When `use_gpu == True and accelerator_options.device == AcceleratorDevice.CUDA` "
|
||||
"then the GPU is used to run EasyOCR. "
|
||||
"When `use_gpu == False`, EasyOCR runs in CPU"
|
||||
),
|
||||
)
|
||||
|
||||
model_storage_directory: Optional[str] = None
|
||||
download_enabled: bool = True # same default as easyocr.Reader
|
||||
|
||||
@ -132,6 +195,7 @@ class PipelineOptions(BaseModel):
|
||||
create_legacy_output: bool = (
|
||||
True # This default will be set to False in a future version of docling
|
||||
)
|
||||
accelerator_options: AcceleratorOptions = AcceleratorOptions()
|
||||
|
||||
|
||||
class PdfPipelineOptions(PipelineOptions):
|
||||
|
@ -7,16 +7,26 @@ from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
|
||||
from docling.datamodel.base_models import Cell, OcrCell, Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import EasyOcrOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
EasyOcrOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyOcrModel(BaseOcrModel):
|
||||
def __init__(self, enabled: bool, options: EasyOcrOptions):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
options: EasyOcrOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
super().__init__(enabled=enabled, options=options)
|
||||
self.options: EasyOcrOptions
|
||||
|
||||
@ -31,11 +41,20 @@ class EasyOcrModel(BaseOcrModel):
|
||||
"Alternatively, Docling has support for other OCR engines. See the documentation."
|
||||
)
|
||||
|
||||
use_gpu = False
|
||||
if self.options.use_gpu:
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
# Enable easyocr GPU if running on CUDA, MPS
|
||||
use_gpu = device in [
|
||||
AcceleratorDevice.CUDA,
|
||||
AcceleratorDevice.MPS,
|
||||
]
|
||||
self.reader = easyocr.Reader(
|
||||
lang_list=self.options.lang,
|
||||
gpu=self.options.use_gpu,
|
||||
gpu=use_gpu,
|
||||
model_storage_directory=self.options.model_storage_directory,
|
||||
download_enabled=self.options.download_enabled,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
@ -17,8 +17,10 @@ from docling.datamodel.base_models import (
|
||||
Page,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils import layout_utils as lu
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
@ -46,8 +48,11 @@ class LayoutModel(BasePageModel):
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
|
||||
def __init__(self, artifacts_path: Path):
|
||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
artifacts_path, device, accelerator_options.num_threads
|
||||
)
|
||||
|
||||
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
|
||||
MIN_INTERSECTION = 0.2
|
||||
|
@ -6,16 +6,26 @@ from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
|
||||
from docling.datamodel.base_models import OcrCell, Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import RapidOcrOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
RapidOcrOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RapidOcrModel(BaseOcrModel):
|
||||
def __init__(self, enabled: bool, options: RapidOcrOptions):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
options: RapidOcrOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
super().__init__(enabled=enabled, options=options)
|
||||
self.options: RapidOcrOptions
|
||||
|
||||
@ -30,52 +40,21 @@ class RapidOcrModel(BaseOcrModel):
|
||||
"Alternatively, Docling has support for other OCR engines. See the documentation."
|
||||
)
|
||||
|
||||
# This configuration option will be revamped while introducing device settings for all models.
|
||||
# For the moment we will default to auto and let onnx-runtime pick the best.
|
||||
cls_use_cuda = True
|
||||
rec_use_cuda = True
|
||||
det_use_cuda = True
|
||||
det_use_dml = True
|
||||
cls_use_dml = True
|
||||
rec_use_dml = True
|
||||
|
||||
# # Same as Defaults in RapidOCR
|
||||
# cls_use_cuda = False
|
||||
# rec_use_cuda = False
|
||||
# det_use_cuda = False
|
||||
# det_use_dml = False
|
||||
# cls_use_dml = False
|
||||
# rec_use_dml = False
|
||||
|
||||
# # If we set everything to true onnx-runtime would automatically choose the fastest accelerator
|
||||
# if self.options.device == self.options.Device.AUTO:
|
||||
# cls_use_cuda = True
|
||||
# rec_use_cuda = True
|
||||
# det_use_cuda = True
|
||||
# det_use_dml = True
|
||||
# cls_use_dml = True
|
||||
# rec_use_dml = True
|
||||
|
||||
# # If we set use_cuda to true onnx would use the cuda device available in runtime if no cuda device is available it would run on CPU.
|
||||
# elif self.options.device == self.options.Device.CUDA:
|
||||
# cls_use_cuda = True
|
||||
# rec_use_cuda = True
|
||||
# det_use_cuda = True
|
||||
|
||||
# # If we set use_dml to true onnx would use the dml device available in runtime if no dml device is available it would work on CPU.
|
||||
# elif self.options.device == self.options.Device.DIRECTML:
|
||||
# det_use_dml = True
|
||||
# cls_use_dml = True
|
||||
# rec_use_dml = True
|
||||
# Decide the accelerator devices
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
use_cuda = device == AcceleratorDevice.CUDA
|
||||
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
|
||||
intra_op_num_threads = accelerator_options.num_threads
|
||||
|
||||
self.reader = RapidOCR(
|
||||
text_score=self.options.text_score,
|
||||
cls_use_cuda=cls_use_cuda,
|
||||
rec_use_cuda=rec_use_cuda,
|
||||
det_use_cuda=det_use_cuda,
|
||||
det_use_dml=det_use_dml,
|
||||
cls_use_dml=cls_use_dml,
|
||||
rec_use_dml=rec_use_dml,
|
||||
cls_use_cuda=use_cuda,
|
||||
rec_use_cuda=use_cuda,
|
||||
det_use_cuda=use_cuda,
|
||||
det_use_dml=use_dml,
|
||||
cls_use_dml=use_dml,
|
||||
rec_use_dml=use_dml,
|
||||
intra_op_num_threads=intra_op_num_threads,
|
||||
print_verbose=self.options.print_verbose,
|
||||
det_model_path=self.options.det_model_path,
|
||||
cls_model_path=self.options.cls_model_path,
|
||||
|
@ -9,15 +9,24 @@ from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import TableFormerMode, TableStructureOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorOptions,
|
||||
TableFormerMode,
|
||||
TableStructureOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import accelerator_utils as au
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
|
||||
class TableStructureModel(BasePageModel):
|
||||
def __init__(
|
||||
self, enabled: bool, artifacts_path: Path, options: TableStructureOptions
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Path,
|
||||
options: TableStructureOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
self.options = options
|
||||
self.do_cell_matching = self.options.do_cell_matching
|
||||
@ -31,11 +40,15 @@ class TableStructureModel(BasePageModel):
|
||||
# Third Party
|
||||
import docling_ibm_models.tableformer.common as c
|
||||
|
||||
device = au.decide_device(accelerator_options.device)
|
||||
|
||||
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
||||
self.tm_config["model"]["save_dir"] = artifacts_path
|
||||
self.tm_model_type = self.tm_config["model"]["type"]
|
||||
|
||||
self.tf_predictor = TFPredictor(self.tm_config)
|
||||
self.tf_predictor = TFPredictor(
|
||||
self.tm_config, device, accelerator_options.num_threads
|
||||
)
|
||||
self.scale = 2.0 # Scale up table input images to 144 dpi
|
||||
|
||||
def draw_table_and_cells(
|
||||
|
@ -38,7 +38,9 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StandardPdfPipeline(PaginatedPipeline):
|
||||
_layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
|
||||
# TODO: Revise after having the models in HF
|
||||
# _layout_model_path = "model_artifacts/layout/beehive_v0.0.5_pt"
|
||||
_layout_model_path = "model_artifacts/layout/"
|
||||
_table_model_path = "model_artifacts/tableformer"
|
||||
|
||||
def __init__(self, pipeline_options: PdfPipelineOptions):
|
||||
@ -75,7 +77,8 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
# Layout model
|
||||
LayoutModel(
|
||||
artifacts_path=self.artifacts_path
|
||||
/ StandardPdfPipeline._layout_model_path
|
||||
/ StandardPdfPipeline._layout_model_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
# Table structure model
|
||||
TableStructureModel(
|
||||
@ -83,6 +86,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
artifacts_path=self.artifacts_path
|
||||
/ StandardPdfPipeline._table_model_path,
|
||||
options=pipeline_options.table_structure_options,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
# Page assemble
|
||||
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
|
||||
@ -98,11 +102,18 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# TODO: Revise after having the models in HF
|
||||
# download_path = snapshot_download(
|
||||
# repo_id="ds4sd/docling-models",
|
||||
# force_download=force,
|
||||
# local_dir=local_dir,
|
||||
# revision="v2.0.1",
|
||||
# )
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/docling-models",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v2.0.1",
|
||||
revision="refs/pr/2",
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
@ -112,6 +123,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
return EasyOcrModel(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
options=self.pipeline_options.ocr_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions):
|
||||
return TesseractOcrCliModel(
|
||||
@ -127,6 +139,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
return RapidOcrModel(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
options=self.pipeline_options.ocr_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
elif isinstance(self.pipeline_options.ocr_options, OcrMacOptions):
|
||||
if "darwin" != sys.platform:
|
||||
|
40
docling/utils/accelerator_utils.py
Normal file
40
docling/utils/accelerator_utils.py
Normal file
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from docling.datamodel.pipeline_options import AcceleratorDevice
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decide_device(accelerator_device: AcceleratorDevice) -> str:
    r"""
    Map the requested accelerator onto a device string usable on this system.

    Resolution rules:
    1. AUTO: use the first CUDA device when CUDA is usable, otherwise the CPU.
       MPS is not yet considered for AUTO.
    2. CUDA / MPS: use the requested device when it is usable; otherwise log a
       warning and fall back to the CPU.
    3. Anything else resolves to the CPU.

    Returns one of "cpu", "cuda:0" or "mps".
    """
    first_gpu = 0

    # Both probes are evaluated up front, independent of the requested device.
    cuda_usable = torch.backends.cuda.is_built() and torch.cuda.is_available()
    mps_usable = torch.backends.mps.is_built() and torch.backends.mps.is_available()

    resolved = "cpu"
    if accelerator_device in (AcceleratorDevice.AUTO, AcceleratorDevice.CUDA):
        if cuda_usable:
            resolved = f"cuda:{first_gpu}"
        elif accelerator_device == AcceleratorDevice.CUDA:
            # Only an explicit CUDA request warrants a warning; AUTO falls
            # back to the CPU silently.
            _log.warning("CUDA is not available in the system. Fall back to 'CPU'")
    elif accelerator_device == AcceleratorDevice.MPS:
        if mps_usable:
            resolved = "mps"
        else:
            _log.warning("MPS is not available in the system. Fall back to 'CPU'")

    _log.info("Accelerator device: '%s'", resolved)
    return resolved
|
63
docs/examples/run_with_accelerator.py
Normal file
63
docs/examples/run_with_accelerator.py
Normal file
@ -0,0 +1,63 @@
|
||||
from pathlib import Path
|
||||
|
||||
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
PdfPipelineOptions,
|
||||
TesseractCliOcrOptions,
|
||||
TesseractOcrOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
|
||||
|
||||
def main():
    """Convert a sample PDF with explicit accelerator settings and print the result."""
    source_pdf = Path("./tests/data/2206.01062.pdf")

    # Explicitly set the accelerator: eight threads on the CPU. Pass another
    # AcceleratorDevice member (AUTO, MPS or CUDA) as `device` to try a
    # different target.
    accel = AcceleratorOptions(num_threads=8, device=AcceleratorDevice.CPU)

    pdf_options = PdfPipelineOptions()
    pdf_options.accelerator_options = accel
    pdf_options.do_ocr = True
    pdf_options.do_table_structure = True
    pdf_options.table_structure_options.do_cell_matching = True

    pdf_format = PdfFormatOption(pipeline_options=pdf_options)
    converter = DocumentConverter(format_options={InputFormat.PDF: pdf_format})

    # Turn on profiling so the conversion result carries timing information.
    settings.debug.profile_pipeline_timings = True

    result = converter.convert(source_pdf)
    document = result.document

    # Timings recorded under the "pipeline_total" key.
    elapsed = result.timings["pipeline_total"].times

    print(document.export_to_markdown())
    print(f"Conversion secs: {elapsed}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -76,6 +76,7 @@ nav:
|
||||
- "Table export": examples/export_tables.py
|
||||
- "Multimodal export": examples/export_multimodal.py
|
||||
- "Force full page OCR": examples/full_page_ocr.py
|
||||
- "Accelerator options": examples/run_with_accelerator.py
|
||||
- RAG / QA:
|
||||
- "RAG with LlamaIndex 🦙": examples/rag_llamaindex.ipynb
|
||||
- "RAG with LangChain 🦜🔗": examples/rag_langchain.ipynb
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@ -5,7 +6,12 @@ import pytest
|
||||
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
|
||||
from docling.datamodel.base_models import ConversionStatus, InputFormat
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions, TableFormerMode
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
PdfPipelineOptions,
|
||||
TableFormerMode,
|
||||
)
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
|
||||
|
||||
@ -35,6 +41,61 @@ def get_converters_with_table_options():
|
||||
yield converter
|
||||
|
||||
|
||||
def test_accelerator_options():
    """Check how AcceleratorOptions resolves values from API arguments and envvars.

    Covers the defaults, explicit keyword arguments, the regular envvars
    (DOCLING_NUM_THREADS, DOCLING_DEVICE), the alternative envvar
    OMP_NUM_THREADS, and invalid values.

    The whole test runs inside a snapshot of os.environ, so the variables it
    sets or removes never leak into other tests, and it starts from a state
    where none of the three variables is set.
    """
    from unittest import mock

    managed_envvars = ("DOCLING_NUM_THREADS", "DOCLING_DEVICE", "OMP_NUM_THREADS")

    # patch.dict restores os.environ to its original content on exit, even if
    # an assertion fails part-way through.
    with mock.patch.dict(os.environ):
        # Start from a known state regardless of the caller's environment.
        for name in managed_envvars:
            os.environ.pop(name, None)

        # Check the default options
        ao = AcceleratorOptions()
        assert ao.num_threads == 4, "Wrong default num_threads"
        assert ao.device == AcceleratorDevice.AUTO, "Wrong default device"

        # Use API
        ao2 = AcceleratorOptions(num_threads=2, device=AcceleratorDevice.MPS)
        ao3 = AcceleratorOptions(num_threads=3, device=AcceleratorDevice.CUDA)
        assert ao2.num_threads == 2
        assert ao2.device == AcceleratorDevice.MPS
        assert ao3.num_threads == 3
        assert ao3.device == AcceleratorDevice.CUDA

        # Use envvars (regular + alternative) and default values
        os.environ["OMP_NUM_THREADS"] = "1"
        ao = AcceleratorOptions()
        assert ao.num_threads == 1
        assert ao.device == AcceleratorDevice.AUTO
        os.environ["DOCLING_DEVICE"] = "cpu"
        ao = AcceleratorOptions()
        assert ao.device == AcceleratorDevice.CPU
        assert ao.num_threads == 1

        # Use envvars and override in init
        os.environ["DOCLING_DEVICE"] = "cpu"
        ao4 = AcceleratorOptions(num_threads=5, device=AcceleratorDevice.MPS)
        assert ao4.num_threads == 5
        assert ao4.device == AcceleratorDevice.MPS

        # Use regular and alternative envvar: the regular one takes precedence
        os.environ["DOCLING_NUM_THREADS"] = "2"
        ao5 = AcceleratorOptions()
        assert ao5.num_threads == 2
        assert ao5.device == AcceleratorDevice.CPU

        # Use wrong values: an unknown device must be rejected
        os.environ["DOCLING_DEVICE"] = "wrong"
        with pytest.raises(Exception):
            AcceleratorOptions()

        # Use misformatted alternative envvar: it is ignored, defaults apply
        del os.environ["DOCLING_NUM_THREADS"]
        del os.environ["DOCLING_DEVICE"]
        os.environ["OMP_NUM_THREADS"] = "wrong"
        ao6 = AcceleratorOptions()
        assert ao6.num_threads == 4
        assert ao6.device == AcceleratorDevice.AUTO
|
||||
|
||||
|
||||
def test_e2e_conversions(test_doc_path):
|
||||
for converter in get_converters_with_table_options():
|
||||
print(f"converting {test_doc_path}")
|
||||
|
Loading…
Reference in New Issue
Block a user