mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-01 15:02:21 +00:00
Pydantic field validator and comment restored.
Signed-off-by: ahn <ahn@zurich.ibm.com>
This commit is contained in:
parent
b9668877be
commit
dd0728e646
@ -41,7 +41,7 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
num_threads: int = 4
|
num_threads: int = 4
|
||||||
device: str = "auto"
|
device: str = "auto"
|
||||||
|
|
||||||
@validator("device")
|
@field_validator("device")
|
||||||
def validate_device(cls, value):
|
def validate_device(cls, value):
|
||||||
# "auto", "cpu", "cuda", "mps", or "cuda:N"
|
# "auto", "cpu", "cuda", "mps", or "cuda:N"
|
||||||
if value in {d.value for d in AcceleratorDevice} or re.match(
|
if value in {d.value for d in AcceleratorDevice} or re.match(
|
||||||
@ -55,6 +55,15 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_alternative_envvars(cls, data: Any) -> Any:
|
def check_alternative_envvars(cls, data: Any) -> Any:
|
||||||
|
r"""
|
||||||
|
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
|
||||||
|
The alternative envvar is used only if it is valid and the regular envvar is not set.
|
||||||
|
|
||||||
|
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
|
||||||
|
the same functionality. In case the alias envvar is set and the user tries to override the
|
||||||
|
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
|
||||||
|
as an extra input instead of simply overwriting the envvar value for that parameter.
|
||||||
|
"""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
input_num_threads = data.get("num_threads")
|
input_num_threads = data.get("num_threads")
|
||||||
if input_num_threads is None:
|
if input_num_threads is None:
|
||||||
|
@ -18,7 +18,6 @@ from docling.datamodel.settings import settings
|
|||||||
from docling.models.base_ocr_model import BaseOcrModel
|
from docling.models.base_ocr_model import BaseOcrModel
|
||||||
from docling.utils.accelerator_utils import decide_device
|
from docling.utils.accelerator_utils import decide_device
|
||||||
from docling.utils.profiling import TimeRecorder
|
from docling.utils.profiling import TimeRecorder
|
||||||
from docling.utils.utils import download_url_with_progress
|
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -56,6 +55,7 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
for x in [
|
for x in [
|
||||||
AcceleratorDevice.CUDA.value,
|
AcceleratorDevice.CUDA.value,
|
||||||
AcceleratorDevice.MPS.value,
|
AcceleratorDevice.MPS.value,
|
||||||
|
"cuda:",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -82,40 +82,6 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def download_models(
|
|
||||||
detection_models: List[str] = ["craft"],
|
|
||||||
recognition_models: List[str] = ["english_g2", "latin_g2"],
|
|
||||||
local_dir: Optional[Path] = None,
|
|
||||||
force: bool = False,
|
|
||||||
progress: bool = False,
|
|
||||||
) -> Path:
|
|
||||||
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
|
|
||||||
from easyocr.config import detection_models as det_models_dict
|
|
||||||
from easyocr.config import recognition_models as rec_models_dict
|
|
||||||
|
|
||||||
if local_dir is None:
|
|
||||||
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
|
|
||||||
|
|
||||||
local_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Collect models to download
|
|
||||||
download_list = []
|
|
||||||
for model_name in detection_models:
|
|
||||||
if model_name in det_models_dict:
|
|
||||||
download_list.append(det_models_dict[model_name])
|
|
||||||
for model_name in recognition_models:
|
|
||||||
if model_name in rec_models_dict["gen2"]:
|
|
||||||
download_list.append(rec_models_dict["gen2"][model_name])
|
|
||||||
|
|
||||||
# Download models
|
|
||||||
for model_details in download_list:
|
|
||||||
buf = download_url_with_progress(model_details["url"], progress=progress)
|
|
||||||
with zipfile.ZipFile(buf, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(local_dir)
|
|
||||||
|
|
||||||
return local_dir
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||||
) -> Iterable[Page]:
|
) -> Iterable[Page]:
|
||||||
|
@ -31,9 +31,7 @@ def main():
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
# easyocr doesn't support cuda:N allocation
|
# easyocr doesn't support cuda:N allocation
|
||||||
# accelerator_options = AcceleratorOptions(
|
# accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:0")
|
||||||
# num_threads=8, device="cuda:1"
|
|
||||||
# )
|
|
||||||
|
|
||||||
pipeline_options = PdfPipelineOptions()
|
pipeline_options = PdfPipelineOptions()
|
||||||
pipeline_options.accelerator_options = accelerator_options
|
pipeline_options.accelerator_options = accelerator_options
|
||||||
@ -59,8 +57,8 @@ def main():
|
|||||||
# List with total time per document
|
# List with total time per document
|
||||||
doc_conversion_secs = conversion_result.timings["pipeline_total"].times
|
doc_conversion_secs = conversion_result.timings["pipeline_total"].times
|
||||||
|
|
||||||
md = doc.export_to_markdown()
|
# md = doc.export_to_markdown()
|
||||||
print(md)
|
# print(md)
|
||||||
print(f"Conversion secs: {doc_conversion_secs}")
|
print(f"Conversion secs: {doc_conversion_secs}")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user