From dcf6fd6a413d4a581344689a46ad8781a47db343 Mon Sep 17 00:00:00 2001
From: Peter Staar
Date: Wed, 9 Jul 2025 06:48:03 +0200
Subject: [PATCH] fix MyPy complaints

Signed-off-by: Peter Staar
---
 docling/datamodel/asr_model_specs.py  | 13 ++--
 docling/datamodel/pipeline_options.py | 13 +++-
 docling/models/layout_model.py        | 16 +++--
 .../hf_transformers_model.py          | 67 ++++++++++++++++---
 docling/pipeline/vlm_pipeline.py      | 10 +--
 5 files changed, 91 insertions(+), 28 deletions(-)

diff --git a/docling/datamodel/asr_model_specs.py b/docling/datamodel/asr_model_specs.py
index 426b5851..5527dd5b 100644
--- a/docling/datamodel/asr_model_specs.py
+++ b/docling/datamodel/asr_model_specs.py
@@ -11,12 +11,13 @@ from docling.datamodel.pipeline_options_asr_model import (
     # ApiAsrOptions,
     InferenceAsrFramework,
     InlineAsrNativeWhisperOptions,
+    InlineAsrOptions,
     TransformersModelType,
 )
 
 _log = logging.getLogger(__name__)
 
-WHISPER_TINY = InlineAsrNativeWhisperOptions(
+WHISPER_TINY: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="tiny",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -27,7 +28,7 @@ WHISPER_TINY = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_SMALL = InlineAsrNativeWhisperOptions(
+WHISPER_SMALL: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="small",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -38,7 +39,7 @@ WHISPER_SMALL = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
+WHISPER_MEDIUM: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="medium",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -49,7 +50,7 @@ WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_BASE = InlineAsrNativeWhisperOptions(
+WHISPER_BASE: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="base",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -60,7 +61,7 @@ WHISPER_BASE = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_LARGE = InlineAsrNativeWhisperOptions(
+WHISPER_LARGE: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="large",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -71,7 +72,7 @@ WHISPER_LARGE = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )
 
-WHISPER_TURBO = InlineAsrNativeWhisperOptions(
+WHISPER_TURBO: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="turbo",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index 329d9de5..2b76a553 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -16,7 +16,15 @@ from docling.datamodel import asr_model_specs
 
 # Import the following for backwards compatibility
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
-from docling.datamodel.asr_model_specs import WHISPER_TINY as whisper_tiny
+from docling.datamodel.asr_model_specs import (
+    WHISPER_BASE,
+    WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
+    WHISPER_TINY as whisper_tiny,
+    WHISPER_TURBO,
+)
 from docling.datamodel.layout_model_specs import (
     LayoutModelConfig,
     docling_layout_egret_large,
@@ -279,13 +287,12 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
 
 class LayoutOptions(BaseModel):
     """Options for layout processing."""
 
-    repo_id: str = "ds4sd/docling-layout-heron"
     create_orphan_clusters: bool = True  # Whether to create clusters for orphaned cells
     model_spec: LayoutModelConfig = docling_layout_v2
 
 
 class AsrPipelineOptions(PipelineOptions):
-    asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY
+    asr_options: Union[InlineAsrOptions] = WHISPER_TINY
     artifacts_path: Optional[Union[Path, str]] = None
 
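The hunks above widen each WHISPER_* constant from the inferred
InlineAsrNativeWhisperOptions to its InlineAsrOptions base type, so call
sites can treat the presets interchangeably. A minimal sketch of the
effect; the _ASR_PRESETS mapping and pick_asr_model helper below are
hypothetical and not part of docling:

    from docling.datamodel.asr_model_specs import WHISPER_BASE, WHISPER_TINY
    from docling.datamodel.pipeline_options_asr_model import InlineAsrOptions

    # Because the constants are declared as InlineAsrOptions, this mapping
    # type-checks even if later entries use a different options subclass.
    _ASR_PRESETS: dict[str, InlineAsrOptions] = {
        "tiny": WHISPER_TINY,
        "base": WHISPER_BASE,
    }

    def pick_asr_model(name: str) -> InlineAsrOptions:
        # Hypothetical helper: raises KeyError for unknown preset names.
        return _ASR_PRESETS[name]
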
diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py
index 2e72d957..6a668bf2 100644
--- a/docling/models/layout_model.py
+++ b/docling/models/layout_model.py
@@ -16,7 +16,7 @@ from docling.datamodel.document import ConversionResult
 from docling.datamodel.layout_model_specs import LayoutModelConfig, docling_layout_v2
 from docling.datamodel.pipeline_options import LayoutOptions
 from docling.datamodel.settings import settings
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseLayoutModel
 from docling.models.utils.hf_model_download import download_hf_model
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.layout_postprocessor import LayoutPostprocessor
@@ -26,7 +26,7 @@ from docling.utils.visualization import draw_clusters
 _log = logging.getLogger(__name__)
 
 
-class LayoutModel(BasePageModel):
+class LayoutModel(BaseLayoutModel):
     TEXT_ELEM_LABELS = [
         DocItemLabel.TEXT,
         DocItemLabel.FOOTNOTE,
@@ -179,7 +179,9 @@ class LayoutModel(BasePageModel):
                         )
                         clusters.append(cluster)
                 """
-                predicted_clusters = self.predict_on_page(page_image=page_image)
+                predicted_clusters = self.predict_on_page_image(
+                    page_image=page_image
+                )
 
                 if settings.debug.visualize_raw_layout:
                     self.draw_clusters_and_cells_side_by_side(
@@ -216,7 +218,9 @@ class LayoutModel(BasePageModel):
                 )
                 """
                 page, processed_clusters, processed_cells = (
-                    self.postprocess_on_page(page=page, clusters=predicted_clusters)
+                    self.postprocess_on_page_image(
+                        page=page, clusters=predicted_clusters
+                    )
                 )
 
                 with warnings.catch_warnings():
@@ -244,7 +248,7 @@ class LayoutModel(BasePageModel):
 
         yield page
 
-    def predict_on_page(self, *, page_image: Image.Image) -> list[Cluster]:
+    def predict_on_page_image(self, *, page_image: Image.Image) -> list[Cluster]:
         pred_items = self.layout_predictor.predict(page_image)
 
         clusters = []
@@ -263,7 +267,7 @@ class LayoutModel(BasePageModel):
 
         return clusters
 
-    def postprocess_on_page(
+    def postprocess_on_page_image(
         self, *, page: Page, clusters: list[Cluster]
     ) -> tuple[Page, list[Cluster], list[TextCell]]:
         processed_clusters, processed_cells = LayoutPostprocessor(
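The rename to predict_on_page_image/postprocess_on_page_image splits
LayoutModel into a page-image-only prediction step and a postprocessing
step that needs the full Page. A sketch of driving the two methods
directly (assumes a constructed LayoutModel and a Page with a renderable
image; the scale value is illustrative):

    # Stage 1: raw cluster prediction from the rendered page image alone.
    page_image = page.get_image(scale=2.0)
    assert page_image is not None
    clusters = layout_model.predict_on_page_image(page_image=page_image)

    # Stage 2: postprocessing also needs the page's cells and geometry,
    # so it takes the Page object plus the predicted clusters.
    page, processed_clusters, processed_cells = (
        layout_model.postprocess_on_page_image(page=page, clusters=clusters)
    )
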
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 2c7b4b0a..4e892119 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -5,10 +5,12 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, Optional
 
+from PIL import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -122,6 +124,43 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
 
+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        # Define prompt structure
+        user_prompt = ""
+        if callable(self.vlm_options.prompt) and page is not None:
+            user_prompt = self.vlm_options.prompt(page.parsed_page)
+        elif isinstance(self.vlm_options.prompt, str):
+            user_prompt = self.vlm_options.prompt
+
+        prompt = self.formulate_prompt(user_prompt)
+        return prompt
+
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
+        output = ""
+
+        inputs = self.processor(
+            text=prompt, images=[page_image], return_tensors="pt"
+        ).to(self.device)
+
+        # Call model to generate:
+        generated_ids = self.vlm_model.generate(
+            **inputs,
+            max_new_tokens=self.max_new_tokens,
+            use_cache=self.use_cache,
+            temperature=self.temperature,
+            generation_config=self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        )
+
+        output = self.processor.batch_decode(
+            generated_ids[:, inputs["input_ids"].shape[1] :],
+            skip_special_tokens=False,
+        )[0]
+
+        return output, []
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -133,22 +172,29 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
             with TimeRecorder(conv_res, "vlm"):
                 assert page.size is not None
 
-                hi_res_image = page.get_image(
+                page_image = page.get_image(
                     scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                 )
+                assert page_image is not None
 
+                # Define prompt structure
+                """
                 if callable(self.vlm_options.prompt):
                     user_prompt = self.vlm_options.prompt(page.parsed_page)
                 else:
                     user_prompt = self.vlm_options.prompt
 
                 prompt = self.formulate_prompt(user_prompt)
-
-                inputs = self.processor(
-                    text=prompt, images=[hi_res_image], return_tensors="pt"
-                ).to(self.device)
+                """
+                prompt = self.get_user_prompt(page=page)
 
                 start_time = time.time()
+
+                """
+                inputs = self.processor(
+                    text=prompt, images=[page_image], return_tensors="pt"
+                ).to(self.device)
+
                 # Call model to generate:
                 generated_ids = self.vlm_model.generate(
                     **inputs,
@@ -169,9 +215,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
                 _log.debug(
                     f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                 )
+                """
+                generated_text, _ = self.predict_on_page_image(
+                    page_image=page_image, prompt=prompt, output_tokens=False
+                )
 
                 page.predictions.vlm_response = VlmPrediction(
-                    text=generated_texts,
-                    generation_time=generation_time,
+                    text=generated_text,
+                    generation_time=time.time() - start_time,
                 )
 
                 yield page
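predict_on_page_image returns a (text, tokens) pair so that a token-level
output path can later sit behind the output_tokens flag without changing
the signature; callers therefore unpack the tuple, as __call__ does above.
A sketch of standalone use (assumes an initialized
HuggingFaceTransformersVlmModel and a rendered page image):

    prompt = vlm_model.get_user_prompt(page=page)
    generated_text, _tokens = vlm_model.predict_on_page_image(
        page_image=page_image, prompt=prompt, output_tokens=False
    )
    # _tokens is currently always an empty list; output_tokens is reserved
    # for future token-level output.
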
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index aac61d8d..0ee06efb 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -115,7 +115,7 @@ class VlmPipeline(PaginatedPipeline):
                 TwoStageVlmOptions, self.pipeline_options.vlm_options
             )
 
-            layout_options = twostagevlm_options.lay_options
+            layout_options = twostagevlm_options.layout_options
             vlm_options = twostagevlm_options.vlm_options
 
             layout_model = LayoutModel(
@@ -125,24 +125,24 @@ class VlmPipeline(PaginatedPipeline):
             )
 
             if vlm_options.inference_framework == InferenceFramework.MLX:
-                vlm_model = HuggingFaceMlxModel(
+                vlm_model_mlx = HuggingFaceMlxModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     artifacts_path=artifacts_path,
                     accelerator_options=pipeline_options.accelerator_options,
                     vlm_options=vlm_options,
                 )
                 self.build_pipe = [
-                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
+                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model_mlx)
                 ]
             elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
-                vlm_model = HuggingFaceTransformersVlmModel(
+                vlm_model_hf = HuggingFaceTransformersVlmModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     artifacts_path=artifacts_path,
                     accelerator_options=pipeline_options.accelerator_options,
                     vlm_options=vlm_options,
                 )
                 self.build_pipe = [
-                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
+                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model_hf)
                 ]
             else:
                 raise ValueError(