diff --git a/docling/models/ds_glm_model.py b/docling/models/ds_glm_model.py index 529b12ce..5fa35af1 100644 --- a/docling/models/ds_glm_model.py +++ b/docling/models/ds_glm_model.py @@ -14,23 +14,24 @@ from docling_core.types import Ref from docling_core.types.experimental import BoundingBox, CoordOrigin from docling_core.types.experimental.document import DoclingDocument from PIL import ImageDraw +from pydantic import BaseModel from docling.datamodel.base_models import Cluster from docling.datamodel.document import ConversionResult -class GlmModel: - def __init__(self, config): - self.config = config - self.create_legacy_output = config.get("create_legacy_output", True) +class GlmOptions(BaseModel): + create_legacy_output: bool = True + model_names: str = "" # e.g. "language;term;reference" + + +class GlmModel: + def __init__(self, options: GlmOptions): + self.options = options + self.create_legacy_output = self.options.create_legacy_output - self.model_names = self.config.get( - "model_names", "" - ) # "language;term;reference" load_pretrained_nlp_models() - # model = init_nlp_model(model_names="language;term;reference") - model = init_nlp_model(model_names=self.model_names) - self.model = model + self.model = init_nlp_model(model_names=self.options.model_names) def __call__( self, conv_res: ConversionResult diff --git a/docling/pipeline/standard_pdf_model_pipeline.py b/docling/pipeline/standard_pdf_model_pipeline.py index 3ec4c17e..cba8609b 100644 --- a/docling/pipeline/standard_pdf_model_pipeline.py +++ b/docling/pipeline/standard_pdf_model_pipeline.py @@ -13,7 +13,7 @@ from docling.datamodel.pipeline_options import ( TesseractOcrOptions, ) from docling.models.base_ocr_model import BaseOcrModel -from docling.models.ds_glm_model import GlmModel +from docling.models.ds_glm_model import GlmModel, GlmOptions from docling.models.easyocr_model import EasyOcrModel from docling.models.layout_model import LayoutModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions @@ -42,7 +42,9 @@ class StandardPdfModelPipeline(PaginatedModelPipeline): self.artifacts_path = Path(artifacts_path) self.glm_model = GlmModel( - config={"create_legacy_output": pipeline_options.create_legacy_output} + options=GlmOptions( + create_legacy_output=pipeline_options.create_legacy_output + ) ) if ocr_model := self.get_ocr_model() is None: