From b2d5c783ae115c469ecc07c28a7faca9f482bb49 Mon Sep 17 00:00:00 2001
From: Peter Staar
Date: Thu, 10 Jul 2025 15:38:15 +0200
Subject: [PATCH] working two-stage vlm approach from the cli

Signed-off-by: Peter Staar
---
 docling/cli/main.py                                  |  7 +++++++
 docling/datamodel/vlm_model_specs.py                 |  2 +-
 .../vlm_models_inline/hf_transformers_model.py       |  4 +++-
 docling/models/vlm_models_inline/mlx_model.py        |  5 ++++-
 .../models/vlm_models_inline/two_stage_vlm_model.py  | 13 ++++++++-----
 5 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/docling/cli/main.py b/docling/cli/main.py
index ae275ea9..1b623b0d 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -63,6 +63,7 @@ from docling.datamodel.vlm_model_specs import (
     GRANITE_VISION_TRANSFORMERS,
     SMOLDOCLING_MLX,
     SMOLDOCLING_TRANSFORMERS,
+    VLM2STAGE,
     VlmModelType,
 )
 from docling.document_converter import (
@@ -627,6 +628,12 @@ def convert(  # noqa: C901
                             "To run SmolDocling faster, please install mlx-vlm:\n"
                             "pip install mlx-vlm"
                         )
+            elif vlm_model == VlmModelType.VLM2STAGE:
+                pipeline_options.vlm_options = VLM2STAGE
+            else:
+                raise ValueError(
+                    f"{vlm_model} is not of type GRANITE_VISION, GRANITE_VISION_OLLAMA, SMOLDOCLING_TRANSFORMERS or VLM2STAGE"
+                )
 
             pdf_format_option = PdfFormatOption(
                 pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py
index 8025d02f..906e4e9c 100644
--- a/docling/datamodel/vlm_model_specs.py
+++ b/docling/datamodel/vlm_model_specs.py
@@ -153,4 +153,4 @@ class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
     GRANITE_VISION = "granite_vision"
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
-    VLM2STAGE = "docling2stage"
+    VLM2STAGE = "vlm2stage"
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 4e892119..5434ee50 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -40,7 +40,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
         self.vlm_options = vlm_options
 
         self.scale = self.vlm_options.scale
-        # self.max_size = self.vlm_options.max_size
+        self.max_size = 512
+        if isinstance(self.vlm_options.max_size, int):
+            self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             import torch
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index c28abe41..fa28de7f 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -35,7 +35,10 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.max_tokens = vlm_options.max_new_tokens
         self.temperature = vlm_options.temperature
         self.scale = self.vlm_options.scale
-        # self.max_size = self.vlm_options.max_size
+
+        self.max_size = 512
+        if isinstance(self.vlm_options.max_size, int):
+            self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             try:
diff --git a/docling/models/vlm_models_inline/two_stage_vlm_model.py b/docling/models/vlm_models_inline/two_stage_vlm_model.py
index 846fe991..2ef18692 100644
--- a/docling/models/vlm_models_inline/two_stage_vlm_model.py
+++ b/docling/models/vlm_models_inline/two_stage_vlm_model.py
@@ -50,12 +50,12 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     page_image = page.get_image(
                         scale=self.vlm_model.scale, max_size=self.vlm_model.max_size
                     )
-
                     assert page_image is not None
 
                     pred_clusters = self.layout_model.predict_on_page_image(
                         page_image=page_image
                     )
+
                     page, processed_clusters, processed_cells = (
                         self.layout_model.postprocess_on_page_image(
                             page=page, clusters=pred_clusters
@@ -68,14 +68,17 @@
                     )
 
                     start_time = time.time()
-                    generated_text = self.vlm_model.predict_on_page_image(
-                        page_image=page_image, prompt=prompt
+                    generated_text, generated_tokens = (
+                        self.vlm_model.predict_on_page_image(
+                            page_image=page_image, prompt=prompt
+                        )
                     )
                     page.predictions.vlm_response = VlmPrediction(
-                        text=generated_text, generation_time=time.time() - start_time
+                        text=generated_text,
+                        generation_time=time.time() - start_time,
+                        generated_tokens=generated_tokens,
                     )
-
                     yield page
 
     def formulate_prompt(self, *, user_prompt: str, clusters: list[Cluster]) -> str:
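
Usage sketch (not part of the commit above): with this patch applied, the two-stage model should be selectable from the CLI roughly as "docling --pipeline vlm --vlm-model vlm2stage input.pdf" (flag names assumed from the existing docling CLI), or programmatically along the lines below. This is a minimal sketch; names not present in the diff (VlmPipelineOptions, InputFormat, export_to_markdown, the example file name) are assumptions based on the surrounding docling API.

    # Minimal sketch, assuming the docling APIs referenced in the hunks above.
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import VlmPipelineOptions  # assumed location
    from docling.datamodel.vlm_model_specs import VLM2STAGE
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    # Mirror what cli/main.py now does for --vlm-model vlm2stage:
    # point the VLM pipeline at the two-stage model spec.
    pipeline_options = VlmPipelineOptions()
    pipeline_options.vlm_options = VLM2STAGE

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
            )
        }
    )

    result = converter.convert("example.pdf")  # hypothetical input file
    print(result.document.export_to_markdown())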