mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-31 14:34:40 +00:00
Fixed rebased issues
Signed-off-by: ahn <ahn@zurich.ibm.com>
This commit is contained in:
parent
a8d1cdfaa5
commit
4b8396cde3
@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import (
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
@ -66,6 +67,7 @@ class AcceleratorOptions(BaseSettings):
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
input_num_threads = data.get("num_threads")
|
||||
# Check if to set the num_threads from the alternative envvar
|
||||
if input_num_threads is None:
|
||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||
omp_num_threads = os.getenv("OMP_NUM_THREADS")
|
||||
|
@ -18,6 +18,7 @@ from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
from docling.utils.utils import download_url_with_progress
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
@ -81,6 +82,40 @@ class EasyOcrModel(BaseOcrModel):
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
detection_models: List[str] = ["craft"],
|
||||
recognition_models: List[str] = ["english_g2", "latin_g2"],
|
||||
local_dir: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
|
||||
from easyocr.config import detection_models as det_models_dict
|
||||
from easyocr.config import recognition_models as rec_models_dict
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
|
||||
|
||||
local_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Collect models to download
|
||||
download_list = []
|
||||
for model_name in detection_models:
|
||||
if model_name in det_models_dict:
|
||||
download_list.append(det_models_dict[model_name])
|
||||
for model_name in recognition_models:
|
||||
if model_name in rec_models_dict["gen2"]:
|
||||
download_list.append(rec_models_dict["gen2"][model_name])
|
||||
|
||||
# Download models
|
||||
for model_details in download_list:
|
||||
buf = download_url_with_progress(model_details["url"], progress=progress)
|
||||
with zipfile.ZipFile(buf, "r") as zip_ref:
|
||||
zip_ref.extractall(local_dir)
|
||||
|
||||
return local_dir
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
@ -31,11 +31,11 @@ def main():
|
||||
# )
|
||||
|
||||
# easyocr doesnt support cuda:N allocation, defaults to cuda:0
|
||||
# accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:1")
|
||||
accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:1")
|
||||
|
||||
pipeline_options = PdfPipelineOptions()
|
||||
pipeline_options.accelerator_options = accelerator_options
|
||||
pipeline_options.do_ocr = True
|
||||
pipeline_options.do_ocr = False
|
||||
pipeline_options.do_table_structure = True
|
||||
pipeline_options.table_structure_options.do_cell_matching = True
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user