mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-02 15:32:30 +00:00
fix artifacts path
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
94751a78f4
commit
18aad34d67
@ -61,5 +61,7 @@ class AppSettings(BaseSettings):
|
|||||||
perf: BatchConcurrencySettings
|
perf: BatchConcurrencySettings
|
||||||
debug: DebugSettings
|
debug: DebugSettings
|
||||||
|
|
||||||
|
cache_dir: Path = Path.home() / ".cache" / "docling"
|
||||||
|
|
||||||
|
|
||||||
settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings())
|
settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings())
|
||||||
|
@ -61,13 +61,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
Processes the given batch of elements and enriches them with predictions.
|
Processes the given batch of elements and enriches them with predictions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_model_repo_folder = "CodeFormula"
|
||||||
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
||||||
expansion_factor = 0.03
|
expansion_factor = 0.03
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
artifacts_path: Optional[Union[Path, str]],
|
artifacts_path: Optional[Path],
|
||||||
options: CodeFormulaModelOptions,
|
options: CodeFormulaModelOptions,
|
||||||
accelerator_options: AcceleratorOptions,
|
accelerator_options: AcceleratorOptions,
|
||||||
):
|
):
|
||||||
@ -98,7 +99,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
if artifacts_path is None:
|
if artifacts_path is None:
|
||||||
artifacts_path = self.download_models_hf()
|
artifacts_path = self.download_models_hf()
|
||||||
else:
|
else:
|
||||||
artifacts_path = Path(artifacts_path)
|
artifacts_path = artifacts_path / self._model_repo_folder
|
||||||
|
|
||||||
self.code_formula_model = CodeFormulaPredictor(
|
self.code_formula_model = CodeFormulaPredictor(
|
||||||
artifacts_path=artifacts_path,
|
artifacts_path=artifacts_path,
|
||||||
|
@ -55,12 +55,13 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
Processes a batch of elements and adds classification annotations.
|
Processes a batch of elements and adds classification annotations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_model_repo_folder = "DocumentFigureClassifier"
|
||||||
images_scale = 2
|
images_scale = 2
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
artifacts_path: Optional[Union[Path, str]],
|
artifacts_path: Optional[Path],
|
||||||
options: DocumentPictureClassifierOptions,
|
options: DocumentPictureClassifierOptions,
|
||||||
accelerator_options: AcceleratorOptions,
|
accelerator_options: AcceleratorOptions,
|
||||||
):
|
):
|
||||||
@ -90,7 +91,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
if artifacts_path is None:
|
if artifacts_path is None:
|
||||||
artifacts_path = self.download_models_hf()
|
artifacts_path = self.download_models_hf()
|
||||||
else:
|
else:
|
||||||
artifacts_path = Path(artifacts_path)
|
artifacts_path = artifacts_path / self._model_repo_folder
|
||||||
|
|
||||||
self.document_picture_classifier = DocumentFigureClassifierPredictor(
|
self.document_picture_classifier = DocumentFigureClassifierPredictor(
|
||||||
artifacts_path=artifacts_path,
|
artifacts_path=artifacts_path,
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable
|
from typing import Iterable, Optional, Union
|
||||||
|
|
||||||
from docling_core.types.doc import DocItemLabel
|
from docling_core.types.doc import DocItemLabel
|
||||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||||
@ -21,6 +22,8 @@ _log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class LayoutModel(BasePageModel):
|
class LayoutModel(BasePageModel):
|
||||||
|
_model_repo_folder = "docling-models"
|
||||||
|
_model_path = "model_artifacts/layout"
|
||||||
|
|
||||||
TEXT_ELEM_LABELS = [
|
TEXT_ELEM_LABELS = [
|
||||||
DocItemLabel.TEXT,
|
DocItemLabel.TEXT,
|
||||||
@ -42,15 +45,53 @@ class LayoutModel(BasePageModel):
|
|||||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||||
|
|
||||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
def __init__(
|
||||||
|
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
|
||||||
|
):
|
||||||
device = decide_device(accelerator_options.device)
|
device = decide_device(accelerator_options.device)
|
||||||
|
|
||||||
|
if artifacts_path is None:
|
||||||
|
artifacts_path = self.download_models_hf() / self._model_path
|
||||||
|
else:
|
||||||
|
# will become the default in the future
|
||||||
|
if (artifacts_path / self._model_repo_folder).exists():
|
||||||
|
artifacts_path = (
|
||||||
|
artifacts_path / self._model_repo_folder / self._model_path
|
||||||
|
)
|
||||||
|
elif (artifacts_path / self._model_path).exists():
|
||||||
|
warnings.warn(
|
||||||
|
"The usage of artifacts_path containing directly "
|
||||||
|
f"{self._model_path} is deprecated. Please point "
|
||||||
|
"the artifacts_path to the parent containing "
|
||||||
|
f"the {self._model_repo_folder} folder.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=3,
|
||||||
|
)
|
||||||
|
artifacts_path = artifacts_path / self._model_path
|
||||||
|
|
||||||
self.layout_predictor = LayoutPredictor(
|
self.layout_predictor = LayoutPredictor(
|
||||||
artifact_path=str(artifacts_path),
|
artifact_path=str(artifacts_path),
|
||||||
device=device,
|
device=device,
|
||||||
num_threads=accelerator_options.num_threads,
|
num_threads=accelerator_options.num_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def download_models_hf(
|
||||||
|
local_dir: Optional[Path] = None, force: bool = False
|
||||||
|
) -> Path:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import disable_progress_bars
|
||||||
|
|
||||||
|
disable_progress_bars()
|
||||||
|
download_path = snapshot_download(
|
||||||
|
repo_id="ds4sd/docling-models",
|
||||||
|
force_download=force,
|
||||||
|
local_dir=local_dir,
|
||||||
|
revision="v2.1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
return Path(download_path)
|
||||||
|
|
||||||
def draw_clusters_and_cells_side_by_side(
|
def draw_clusters_and_cells_side_by_side(
|
||||||
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
||||||
):
|
):
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable
|
from typing import Iterable, Optional, Union
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
|
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
|
||||||
@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder
|
|||||||
|
|
||||||
|
|
||||||
class TableStructureModel(BasePageModel):
|
class TableStructureModel(BasePageModel):
|
||||||
|
_model_repo_folder = "docling-models"
|
||||||
|
_model_path = "model_artifacts/tableformer"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
artifacts_path: Path,
|
artifacts_path: Optional[Path],
|
||||||
options: TableStructureOptions,
|
options: TableStructureOptions,
|
||||||
accelerator_options: AcceleratorOptions,
|
accelerator_options: AcceleratorOptions,
|
||||||
):
|
):
|
||||||
@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel):
|
|||||||
|
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
|
|
||||||
|
if artifacts_path is None:
|
||||||
|
artifacts_path = self.download_models_hf() / self._model_path
|
||||||
|
else:
|
||||||
|
# will become the default in the future
|
||||||
|
if (artifacts_path / self._model_repo_folder).exists():
|
||||||
|
artifacts_path = (
|
||||||
|
artifacts_path / self._model_repo_folder / self._model_path
|
||||||
|
)
|
||||||
|
elif (artifacts_path / self._model_path).exists():
|
||||||
|
warnings.warn(
|
||||||
|
"The usage of artifacts_path containing directly "
|
||||||
|
f"{self._model_path} is deprecated. Please point "
|
||||||
|
"the artifacts_path to the parent containing "
|
||||||
|
f"the {self._model_repo_folder} folder.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=3,
|
||||||
|
)
|
||||||
|
artifacts_path = artifacts_path / self._model_path
|
||||||
|
|
||||||
if self.mode == TableFormerMode.ACCURATE:
|
if self.mode == TableFormerMode.ACCURATE:
|
||||||
artifacts_path = artifacts_path / "accurate"
|
artifacts_path = artifacts_path / "accurate"
|
||||||
else:
|
else:
|
||||||
@ -58,6 +82,23 @@ class TableStructureModel(BasePageModel):
|
|||||||
)
|
)
|
||||||
self.scale = 2.0 # Scale up table input images to 144 dpi
|
self.scale = 2.0 # Scale up table input images to 144 dpi
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def download_models_hf(
|
||||||
|
local_dir: Optional[Path] = None, force: bool = False
|
||||||
|
) -> Path:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import disable_progress_bars
|
||||||
|
|
||||||
|
disable_progress_bars()
|
||||||
|
download_path = snapshot_download(
|
||||||
|
repo_id="ds4sd/docling-models",
|
||||||
|
force_download=force,
|
||||||
|
local_dir=local_dir,
|
||||||
|
revision="v2.1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
return Path(download_path)
|
||||||
|
|
||||||
def draw_table_and_cells(
|
def draw_table_and_cells(
|
||||||
self,
|
self,
|
||||||
conv_res: ConversionResult,
|
conv_res: ConversionResult,
|
||||||
|
@ -17,6 +17,7 @@ from docling.datamodel.pipeline_options import (
|
|||||||
TesseractCliOcrOptions,
|
TesseractCliOcrOptions,
|
||||||
TesseractOcrOptions,
|
TesseractOcrOptions,
|
||||||
)
|
)
|
||||||
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_ocr_model import BaseOcrModel
|
from docling.models.base_ocr_model import BaseOcrModel
|
||||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||||
from docling.models.document_picture_classifier import (
|
from docling.models.document_picture_classifier import (
|
||||||
@ -43,17 +44,16 @@ _log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class StandardPdfPipeline(PaginatedPipeline):
|
class StandardPdfPipeline(PaginatedPipeline):
|
||||||
_layout_model_path = "model_artifacts/layout"
|
_layout_model_path = LayoutModel._model_path
|
||||||
_table_model_path = "model_artifacts/tableformer"
|
_table_model_path = TableStructureModel._model_path
|
||||||
|
|
||||||
def __init__(self, pipeline_options: PdfPipelineOptions):
|
def __init__(self, pipeline_options: PdfPipelineOptions):
|
||||||
super().__init__(pipeline_options)
|
super().__init__(pipeline_options)
|
||||||
self.pipeline_options: PdfPipelineOptions
|
self.pipeline_options: PdfPipelineOptions
|
||||||
|
|
||||||
if pipeline_options.artifacts_path is None:
|
artifacts_path: Optional[Path] = None
|
||||||
self.artifacts_path = self.download_models_hf()
|
if pipeline_options.artifacts_path is not None:
|
||||||
else:
|
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
|
||||||
self.artifacts_path = Path(pipeline_options.artifacts_path)
|
|
||||||
|
|
||||||
self.keep_images = (
|
self.keep_images = (
|
||||||
self.pipeline_options.generate_page_images
|
self.pipeline_options.generate_page_images
|
||||||
@ -79,15 +79,13 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
ocr_model,
|
ocr_model,
|
||||||
# Layout model
|
# Layout model
|
||||||
LayoutModel(
|
LayoutModel(
|
||||||
artifacts_path=self.artifacts_path
|
artifacts_path=artifacts_path,
|
||||||
/ StandardPdfPipeline._layout_model_path,
|
|
||||||
accelerator_options=pipeline_options.accelerator_options,
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
),
|
),
|
||||||
# Table structure model
|
# Table structure model
|
||||||
TableStructureModel(
|
TableStructureModel(
|
||||||
enabled=pipeline_options.do_table_structure,
|
enabled=pipeline_options.do_table_structure,
|
||||||
artifacts_path=self.artifacts_path
|
artifacts_path=artifacts_path,
|
||||||
/ StandardPdfPipeline._table_model_path,
|
|
||||||
options=pipeline_options.table_structure_options,
|
options=pipeline_options.table_structure_options,
|
||||||
accelerator_options=pipeline_options.accelerator_options,
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
),
|
),
|
||||||
@ -101,7 +99,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
CodeFormulaModel(
|
CodeFormulaModel(
|
||||||
enabled=pipeline_options.do_code_enrichment
|
enabled=pipeline_options.do_code_enrichment
|
||||||
or pipeline_options.do_formula_enrichment,
|
or pipeline_options.do_formula_enrichment,
|
||||||
artifacts_path=pipeline_options.artifacts_path,
|
artifacts_path=artifacts_path,
|
||||||
options=CodeFormulaModelOptions(
|
options=CodeFormulaModelOptions(
|
||||||
do_code_enrichment=pipeline_options.do_code_enrichment,
|
do_code_enrichment=pipeline_options.do_code_enrichment,
|
||||||
do_formula_enrichment=pipeline_options.do_formula_enrichment,
|
do_formula_enrichment=pipeline_options.do_formula_enrichment,
|
||||||
@ -111,7 +109,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
# Document Picture Classifier
|
# Document Picture Classifier
|
||||||
DocumentPictureClassifier(
|
DocumentPictureClassifier(
|
||||||
enabled=pipeline_options.do_picture_classification,
|
enabled=pipeline_options.do_picture_classification,
|
||||||
artifacts_path=pipeline_options.artifacts_path,
|
artifacts_path=artifacts_path,
|
||||||
options=DocumentPictureClassifierOptions(),
|
options=DocumentPictureClassifierOptions(),
|
||||||
accelerator_options=pipeline_options.accelerator_options,
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
),
|
),
|
||||||
@ -127,18 +125,29 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
def download_models_hf(
|
def download_models_hf(
|
||||||
local_dir: Optional[Path] = None, force: bool = False
|
local_dir: Optional[Path] = None, force: bool = False
|
||||||
) -> Path:
|
) -> Path:
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from huggingface_hub.utils import disable_progress_bars
|
|
||||||
|
|
||||||
disable_progress_bars()
|
if local_dir is None:
|
||||||
download_path = snapshot_download(
|
local_dir = settings.cache_dir / "models"
|
||||||
repo_id="ds4sd/docling-models",
|
|
||||||
force_download=force,
|
# Make sure the folder exists
|
||||||
local_dir=local_dir,
|
local_dir.mkdir(exist_ok=True, parents=True)
|
||||||
revision="v2.1.0",
|
|
||||||
|
# Download model weights
|
||||||
|
LayoutModel.download_models_hf(
|
||||||
|
local_dir=local_dir / LayoutModel._model_repo_folder, force=force
|
||||||
|
)
|
||||||
|
TableStructureModel.download_models_hf(
|
||||||
|
local_dir=local_dir / TableStructureModel._model_repo_folder, force=force
|
||||||
|
)
|
||||||
|
DocumentPictureClassifier.download_models_hf(
|
||||||
|
local_dir=local_dir / DocumentPictureClassifier._model_repo_folder,
|
||||||
|
force=force,
|
||||||
|
)
|
||||||
|
CodeFormulaModel.download_models_hf(
|
||||||
|
local_dir=local_dir / CodeFormulaModel._model_repo_folder, force=force
|
||||||
)
|
)
|
||||||
|
|
||||||
return Path(download_path)
|
return local_dir
|
||||||
|
|
||||||
def get_ocr_model(self) -> Optional[BaseOcrModel]:
|
def get_ocr_model(self) -> Optional[BaseOcrModel]:
|
||||||
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
|
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
|
||||||
|
Loading…
Reference in New Issue
Block a user