From f42676aab91bcd1b06e256eb16cdc926d11d524c Mon Sep 17 00:00:00 2001
From: Christoph Auer
Date: Fri, 15 Aug 2025 17:56:14 +0200
Subject: [PATCH] Implement proper batch inference for HuggingFaceTransformersVlmModel

Signed-off-by: Christoph Auer
---
 docling/datamodel/vlm_model_specs.py          | 49 +++++++
 docling/models/readingorder_model.py          |  2 +-
 .../hf_transformers_model.py                  | 135 ++++++++++--------
 docling/models/vlm_models_inline/mlx_model.py | 70 ++++++---
 .../threaded_multistage_vlm_pipeline.py       | 28 ++--
 docling/pipeline/vlm_pipeline.py              |  4 +-
 6 files changed, 193 insertions(+), 95 deletions(-)

diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py
index 0abb1642..d09e0a81 100644
--- a/docling/datamodel/vlm_model_specs.py
+++ b/docling/datamodel/vlm_model_specs.py
@@ -27,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
     supported_devices=[AcceleratorDevice.MPS],
     scale=2.0,
     temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
 )
 
 SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
@@ -42,6 +43,7 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     ],
     scale=2.0,
     temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
 )
 
 SMOLDOCLING_VLLM = InlineVlmOptions(
@@ -56,8 +58,55 @@ SMOLDOCLING_VLLM = InlineVlmOptions(
     ],
     scale=2.0,
     temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
 )
 
+# SmolVLM-500M-Instruct (Transformers)
+SMOLVLM500_TRANSFORMERS = InlineVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
+# SmolVLM-500M-Instruct (MLX)
+SMOLVLM500_MLX = InlineVlmOptions(
+    repo_id="moot20/SmolVLM-500M-Instruct-MLX",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.MLX,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
+# SmolVLM-500M-Instruct (vLLM)
+SMOLVLM500_VLLM = InlineVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.VLLM,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 # GraniteVision
 GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
diff --git a/docling/models/readingorder_model.py b/docling/models/readingorder_model.py
index 375ad4e4..b98b8749 100644
--- a/docling/models/readingorder_model.py
+++ b/docling/models/readingorder_model.py
@@ -316,7 +316,7 @@ class ReadingOrderModel:
                 current_list = None
 
                 new_item = out_doc.add_text(
-                    label=DocItemLabel.FORMULA, text="", orig=cap_text, prov=prov
+                    label=DocItemLabel.FORMULA, text=cap_text, orig=cap_text, prov=prov
                 )
             else:
                 current_list = None
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index e81eafc3..9d95e9c6 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
 
 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteriaList, StopStringCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -227,109 +228,119 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         image_batch: Iterable[Union[Image, np.ndarray]],
         prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata in a single batched inference call.
-
-        Args:
-            image_batch: Iterable of PIL Images or numpy arrays
-            prompt: Either:
-                - str: Single prompt used for all images
-                - list[str]: List of prompts (one per image, must match image count)
-
-        Raises:
-            ValueError: If prompt list length doesn't match image count.
+        """Process raw images without page metadata in a single batched inference call.
+
+        The processor handles text+image preprocessing and batch padding, and
+        generated sequences are trimmed per row via the attention mask, so no
+        pad-token-id fallbacks are needed.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays.
+            prompt: Either a single prompt applied to all images, or a list of
+                prompts (one per image, must match the image count).
+
+        Raises:
+            ValueError: If the prompt list length doesn't match the image count.
         """
+        import torch  # local import, consistent with the lazy torch usage elsewhere
+        from PIL import Image as PILImage  # module import for fromarray(); top-level `Image` is the class
+
+        # -- Normalize images to RGB PIL (the processor accepts PIL or numpy images)
         pil_images: list[Image] = []
         for img in image_batch:
-            # Convert numpy array to PIL Image if needed
             if isinstance(img, np.ndarray):
-                if img.ndim == 3 and img.shape[2] in [3, 4]:
-                    from PIL import Image as PILImage
-
+                if img.ndim == 3 and img.shape[2] in (3, 4):
                     pil_img = PILImage.fromarray(img.astype(np.uint8))
                 elif img.ndim == 2:
-                    from PIL import Image as PILImage
-
                     pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                 else:
                     raise ValueError(f"Unsupported numpy array shape: {img.shape}")
             else:
                 pil_img = img
-
-            # Ensure image is in RGB mode (handles RGBA, L, etc.)
             if pil_img.mode != "RGB":
                 pil_img = pil_img.convert("RGB")
             pil_images.append(pil_img)
 
-        if len(pil_images) == 0:
+        if not pil_images:
             return
 
-        # Handle prompt parameter
+        # -- Normalize prompts (one per image)
         if isinstance(prompt, str):
-            # Single prompt for all images
             user_prompts = [prompt] * len(pil_images)
-        elif isinstance(prompt, list):
-            # List of prompts (one per image)
+        else:
             if len(prompt) != len(pil_images):
                 raise ValueError(
                     f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
                 )
             user_prompts = prompt
-        else:
-            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
 
-        # Format prompts individually
-        prompts: list[str] = [
-            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
-        ]
+        # Apply the model's chat template to each user prompt
+        prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
 
+        # -- The processor performs both text+image preprocessing and batch padding
         inputs = self.processor(
-            text=prompts, images=pil_images, return_tensors="pt", padding=True
-        ).to(self.device)
+            text=prompts,
+            images=pil_images,
+            return_tensors="pt",
+            padding=True,  # pad across the batch for both text and vision
+        )
+        inputs = {
+            k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()
+        }
+
+        # -- Optional stopping criteria built from the model spec's stop_strings
+        stopping_criteria = None
+        if self.vlm_options.stop_strings:
+            stopping_criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(
+                        stop_strings=self.vlm_options.stop_strings,
+                        tokenizer=self.processor.tokenizer,
+                    )
+                ]
+            )
+
+        # -- Assemble generation kwargs (the image-text-to-text model consumes
+        #    the processor outputs directly)
+        gen_kwargs = {
+            **inputs,
+            "max_new_tokens": self.max_new_tokens,
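+            # use_cache enables the KV cache so past key/values are reused across
+            # decoding steps instead of being recomputed at every step.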
"use_cache": self.use_cache, + "generation_config": self.generation_config, + "temperature": self.temperature, + **self.vlm_options.extra_generation_config, + } + if stopping_criteria is not None: + gen_kwargs["stopping_criteria"] = stopping_criteria start_time = time.time() - generated_ids = self.vlm_model.generate( - **inputs, - max_new_tokens=self.max_new_tokens, - use_cache=self.use_cache, - # temperature=self.temperature, - generation_config=self.generation_config, - **self.vlm_options.extra_generation_config, - ) + with torch.no_grad(): + generated_ids = self.vlm_model.generate(**gen_kwargs) generation_time = time.time() - start_time - # Determine per-sample prompt lengths - try: - attention_mask = inputs["attention_mask"] - input_lengths: list[int] = attention_mask.sum(dim=1).tolist() - except KeyError: - tokenizer = ( - self.processor.tokenizer - ) # Expect tokenizer to be present when text is provided - pad_token_id = tokenizer.pad_token_id - if pad_token_id is not None: - input_lengths = ( - (inputs["input_ids"] != pad_token_id).sum(dim=1).tolist() - ) - else: - # Fallback: assume uniform prompt length (least accurate but preserves execution) - uniform_len = int(inputs["input_ids"].shape[1]) - input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0]) + # -- Trim per sample using attention_mask (robust for batched prompts) + if "attention_mask" not in inputs: + raise RuntimeError( + "Processor did not return 'attention_mask'. Ensure padding=True and text tokenization are enabled." + ) + input_lengths = inputs["attention_mask"].sum(dim=1).tolist() trimmed_sequences: list[list[int]] = [ generated_ids[i, int(input_lengths[i]) :].tolist() for i in range(generated_ids.shape[0]) ] - decoded_texts: list[str] = self.processor.batch_decode( + + # -- Decode with the processor/tokenizer (skip specials, keep DocTags as text) + decode_fn = getattr(self.processor, "batch_decode", None) + if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None: + decode_fn = self.processor.tokenizer.batch_decode + if decode_fn is None: + raise RuntimeError( + "Neither processor.batch_decode nor tokenizer.batch_decode is available." + ) + + decoded_texts: list[str] = decode_fn( trimmed_sequences, skip_special_tokens=True ) - # Logging tokens count for the first sample as a representative metric + # -- Optional logging if generated_ids.shape[0] > 0: - num_tokens = int(generated_ids[0].shape[0]) _log.debug( - f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds." + f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s " + f"for batch size {generated_ids.shape[0]}." ) for text in decoded_texts: diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py index 60894869..7a3b2410 100644 --- a/docling/models/vlm_models_inline/mlx_model.py +++ b/docling/models/vlm_models_inline/mlx_model.py @@ -142,8 +142,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): Raises: ValueError: If prompt list length doesn't match image count. 
""" - from mlx_vlm import generate - # Convert image batch to list for length validation image_list = list(image_batch) @@ -194,33 +192,67 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): self.processor, self.config, user_prompt, num_images=1 ) - # Generate text from the image - MLX can accept PIL Images directly despite type annotations + # Stream generate with stop strings support start_time = time.time() - generated_result = generate( + _log.debug("start generating ...") + + tokens: list[VlmPredictionToken] = [] + output = "" + + # Use stream_generate for proper stop string handling + for token in self.stream_generate( self.vlm_model, self.processor, formatted_prompt, - image=image, # Pass PIL Image directly - much more efficient than disk I/O + [image], # MLX stream_generate expects list of images + max_tokens=self.max_tokens, verbose=False, temp=self.temperature, - max_tokens=self.max_tokens, - ) + ): + # Collect token information + if len(token.logprobs.shape) == 1: + tokens.append( + VlmPredictionToken( + text=token.text, + token=token.token, + logprob=token.logprobs[token.token], + ) + ) + elif ( + len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1 + ): + tokens.append( + VlmPredictionToken( + text=token.text, + token=token.token, + logprob=token.logprobs[0, token.token], + ) + ) + else: + _log.warning( + f"incompatible shape for logprobs: {token.logprobs.shape}" + ) + + output += token.text + + # Check for any configured stop strings + if self.vlm_options.stop_strings: + if any( + stop_str in output + for stop_str in self.vlm_options.stop_strings + ): + _log.debug("Stopping generation due to stop string match") + break + generation_time = time.time() - start_time - # MLX generate returns a tuple (text, info_dict), extract just the text - if isinstance(generated_result, tuple): - generated_text = generated_result[0] - _log.debug( - f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}" - ) - else: - generated_text = generated_result + _log.debug( + f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)." + ) - _log.debug(f"Generated text in {generation_time:.2f}s.") yield VlmPrediction( - text=generated_text, + text=output, generation_time=generation_time, - # MLX generate doesn't expose tokens directly, so we leave it empty - generated_tokens=[], + generated_tokens=tokens, ) _log.debug("MLX model: Released global lock") diff --git a/docling/pipeline/threaded_multistage_vlm_pipeline.py b/docling/pipeline/threaded_multistage_vlm_pipeline.py index 9921af87..a9dd9e64 100644 --- a/docling/pipeline/threaded_multistage_vlm_pipeline.py +++ b/docling/pipeline/threaded_multistage_vlm_pipeline.py @@ -57,6 +57,8 @@ from docling.datamodel.vlm_model_specs import ( DOLPHIN_TRANSFORMERS, SMOLDOCLING_MLX, SMOLDOCLING_TRANSFORMERS, + SMOLVLM500_MLX, + SMOLVLM500_TRANSFORMERS, ) from docling.models.layout_model import LayoutModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions @@ -159,21 +161,23 @@ class ThreadedMultiStageVlmPipelineOptions(PaginatedPipelineOptions): # text_opts = DOLPHIN_TRANSFORMERS.model_copy() # text_opts.prompt = "Read text in the image. " - base_model = SMOLDOCLING_TRANSFORMERS - - formula_opts = base_model.model_copy() - formula_opts.prompt = "Convert formula to latex." 
-        formula_opts.response_format = ResponseFormat.OTSL
-
-        code_opts = base_model.model_copy()
-        code_opts.prompt = "Convert code to text."
-        code_opts.response_format = ResponseFormat.OTSL
+        base_model = SMOLVLM500_TRANSFORMERS
 
         text_opts = base_model.model_copy()
-        text_opts.prompt = "Convert this page to docling."
-        text_opts.response_format = ResponseFormat.OTSL
+        # text_opts.prompt = "Convert this page to docling."
+        text_opts.prompt = "What does this say?"
+        text_opts.response_format = ResponseFormat.PLAINTEXT
 
-        table_opts = base_model.model_copy()
+        formula_opts = base_model.model_copy()
+        # formula_opts.prompt = "Convert formula to latex."
+        formula_opts.prompt = "What does this say?"
+        formula_opts.response_format = ResponseFormat.PLAINTEXT
+
+        code_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
+        code_opts.prompt = "Convert code to text."
+        code_opts.response_format = ResponseFormat.DOCTAGS
+
+        table_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
         table_opts.prompt = "Convert this table to OTSL."
         table_opts.response_format = ResponseFormat.OTSL
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index 5a05b9a0..d69f2485 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -128,7 +128,9 @@ class VlmPipeline(PaginatedPipeline):
                 page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
             if page._backend is not None and page._backend.is_valid():
                 page.size = page._backend.get_size()
-                page.parsed_page = page._backend.get_segmented_page()
+
+                if self.force_backend_text:
+                    page.parsed_page = page._backend.get_segmented_page()
 
         return page
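
---
Usage sketch (illustrative, not part of the patch): the snippet below shows how
the batched inference path added to hf_transformers_model.py could be driven
directly. The entry-point name `process_images` and the constructor signature
are assumptions inferred from this diff and docling's other inline VLM models;
check them against your docling version.

    from PIL import Image

    from docling.datamodel.accelerator_options import AcceleratorOptions
    from docling.datamodel.vlm_model_specs import SMOLDOCLING_TRANSFORMERS
    from docling.models.vlm_models_inline.hf_transformers_model import (
        HuggingFaceTransformersVlmModel,
    )

    # The SmolDocling spec now carries stop_strings, so batched generation
    # halts at "</doctag>" / "<end_of_utterance>" via StopStringCriteria.
    model = HuggingFaceTransformersVlmModel(
        enabled=True,
        artifacts_path=None,  # assumption: weights are fetched on first use
        accelerator_options=AcceleratorOptions(),
        vlm_options=SMOLDOCLING_TRANSFORMERS,
    )

    images = [Image.open("page1.png"), Image.open("page2.png")]

    # A single str prompt is broadcast to every image; a list[str] with one
    # prompt per image is also accepted (mismatched lengths raise ValueError).
    for prediction in model.process_images(images, SMOLDOCLING_TRANSFORMERS.prompt):
        print(prediction.text)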