diff --git a/docling/cli/main.py b/docling/cli/main.py
index 2177b788..a31ca274 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -66,6 +66,7 @@ from docling.datamodel.vlm_model_specs import (
     GRANITE_VISION_TRANSFORMERS,
     GRANITEDOCLING_MLX,
     GRANITEDOCLING_TRANSFORMERS,
+    GRANITEDOCLING_VLLM,
     SMOLDOCLING_MLX,
     SMOLDOCLING_TRANSFORMERS,
     SMOLDOCLING_VLLM,
@@ -686,6 +687,7 @@ def convert(  # noqa: C901
                         "To run SmolDocling faster, please install mlx-vlm:\n"
                         "pip install mlx-vlm"
                     )
+
         elif vlm_model == VlmModelType.GRANITEDOCLING:
             pipeline_options.vlm_options = GRANITEDOCLING_TRANSFORMERS
             if sys.platform == "darwin":
@@ -701,6 +703,9 @@ def convert(  # noqa: C901
         elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
             pipeline_options.vlm_options = SMOLDOCLING_VLLM

+        elif vlm_model == VlmModelType.GRANITEDOCLING_VLLM:
+            pipeline_options.vlm_options = GRANITEDOCLING_VLLM
+
         pdf_format_option = PdfFormatOption(
             pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
         )
diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index a9aeffbe..8502822e 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -53,6 +53,7 @@ class InlineVlmOptions(BaseVlmOptions):
     kind: Literal["inline_model_options"] = "inline_model_options"

     repo_id: str
+    revision: str = "main"
     trust_remote_code: bool = False
     load_in_8bit: bool = True
     llm_int8_threshold: float = 6.0
diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py
index 652e0afd..54d0c3e9 100644
--- a/docling/datamodel/vlm_model_specs.py
+++ b/docling/datamodel/vlm_model_specs.py
@@ -29,12 +29,20 @@ GRANITEDOCLING_TRANSFORMERS = InlineVlmOptions(
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
     ],
+    extra_generation_config=dict(skip_special_tokens=False),
     scale=2.0,
     temperature=0.0,
     max_new_tokens=8192,
     stop_strings=["</doctag>", "<|end_of_text|>"],
 )

+GRANITEDOCLING_VLLM = GRANITEDOCLING_TRANSFORMERS.model_copy()
+GRANITEDOCLING_VLLM.inference_framework = InferenceFramework.VLLM
+GRANITEDOCLING_VLLM.revision = (
+    "untied"  # change back to "main" with next vllm release after 0.10.2
+)
+
+
 GRANITEDOCLING_MLX = InlineVlmOptions(
     repo_id="ibm-granite/granite-docling-258M-mlx",
     prompt="Convert this page to docling.",
@@ -302,3 +310,4 @@ class VlmModelType(str, Enum):
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
     GOT_OCR_2 = "got_ocr_2"
     GRANITEDOCLING = "granite_docling"
+    GRANITEDOCLING_VLLM = "granite_docling_vllm"
diff --git a/docling/models/base_model.py b/docling/models/base_model.py
index bc78b78b..5d443c7b 100644
--- a/docling/models/base_model.py
+++ b/docling/models/base_model.py
@@ -88,7 +88,8 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel):

         if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
             return user_prompt
-
+        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            return ""
         elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             _log.debug("Using specialized prompt for Phi-4")
             # Note: This might need adjustment for VLLM vs transformers
diff --git a/docling/models/utils/hf_model_download.py b/docling/models/utils/hf_model_download.py
index 3595166a..0ed5933c 100644
--- a/docling/models/utils/hf_model_download.py
+++ b/docling/models/utils/hf_model_download.py
@@ -34,7 +34,12 @@ class HuggingFaceModelDownloadMixin:
         local_dir: Optional[Path] = None,
         force: bool = False,
         progress: bool = False,
+        revision: Optional[str] = None,
     ) -> Path:
         return download_hf_model(
-            repo_id=repo_id, local_dir=local_dir, force=force, progress=progress
+            repo_id=repo_id,
+            local_dir=local_dir,
+            force=force,
+            progress=progress,
+            revision=revision,
         )
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index e7157948..25eb9b88 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -75,7 +75,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")

             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = self.download_models(
+                    self.vlm_options.repo_id, revision=self.vlm_options.revision
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder

@@ -106,6 +108,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=vlm_options.trust_remote_code,
+                revision=vlm_options.revision,
             )
             self.processor.tokenizer.padding_side = "left"

@@ -120,11 +123,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                     else "sdpa"
                 ),
                 trust_remote_code=vlm_options.trust_remote_code,
+                revision=vlm_options.revision,
             )
             self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
             # Load generation config
-            self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
+            self.generation_config = GenerationConfig.from_pretrained(
+                artifacts_path, revision=vlm_options.revision
+            )

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -196,7 +202,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         import torch
         from PIL import Image as PILImage

-        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []

         for img in image_batch:
@@ -258,13 +264,30 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             ]
         )

+        # -- Filter out decoder-specific keys from extra_generation_config
+        decoder_keys = {
+            "skip_special_tokens",
+            "clean_up_tokenization_spaces",
+            "spaces_between_special_tokens",
+        }
+        generation_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k not in decoder_keys
+        }
+        decoder_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k in decoder_keys
+        }
+
         # -- Generate (Image-Text-to-Text class expects these inputs from processor)
         gen_kwargs = {
             **inputs,
             "max_new_tokens": self.max_new_tokens,
             "use_cache": self.use_cache,
             "generation_config": self.generation_config,
-            **self.vlm_options.extra_generation_config,
+            **generation_config,
         }
         if self.temperature > 0:
             gen_kwargs["do_sample"] = True
@@ -293,7 +316,8 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             )

         decoded_texts: list[str] = decode_fn(
-            trimmed_sequences, skip_special_tokens=False
+            trimmed_sequences,
+            **decoder_config,
         )

         # -- Clip off pad tokens from decoded texts
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 52e786f2..1ee588c7 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -60,6 +60,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             if artifacts_path is None:
                 artifacts_path = self.download_models(
                     self.vlm_options.repo_id,
+                    revision=self.vlm_options.revision,
                 )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
diff --git a/docling/models/vlm_models_inline/vllm_model.py b/docling/models/vlm_models_inline/vllm_model.py
index b92b9236..8f019b96 100644
--- a/docling/models/vlm_models_inline/vllm_model.py
+++ b/docling/models/vlm_models_inline/vllm_model.py
@@ -7,9 +7,7 @@ from typing import Any, Dict, Optional, Union
 import numpy as np
 from PIL.Image import Image

-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
+from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
@@ -17,9 +15,7 @@
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
+from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder

@@ -27,6 +23,62 @@ _log = logging.getLogger(__name__)


 class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    """
+    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
+    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
+    """
+
+    # --------- Allowlist of vLLM args ---------
+    # SamplingParams (runtime generation controls)
+    _VLLM_SAMPLING_KEYS = {
+        # Core
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "top_k",
+        # Penalties
+        "presence_penalty",
+        "frequency_penalty",
+        "repetition_penalty",
+        # Stops / outputs
+        "stop",
+        "stop_token_ids",
+        "skip_special_tokens",
+        "spaces_between_special_tokens",
+        # Search / length
+        "n",
+        "best_of",
+        "length_penalty",
+        "early_stopping",
+        # Misc
+        "logprobs",
+        "prompt_logprobs",
+        "min_p",
+        "seed",
+    }
+
+    # LLM(...) / EngineArgs (engine/load-time controls)
+    _VLLM_ENGINE_KEYS = {
+        # Model/tokenizer/impl
+        "tokenizer",
+        "tokenizer_mode",
+        "download_dir",
+        # Parallelism / memory / lengths
+        "tensor_parallel_size",
+        "pipeline_parallel_size",
+        "gpu_memory_utilization",
+        "max_model_len",
+        "max_num_batched_tokens",
+        "kv_cache_dtype",
+        "dtype",
+        # Quantization (coarse switch)
+        "quantization",
+        # Multimodal limits
+        "limit_mm_per_prompt",
+        # Execution toggles
+        "enforce_eager",
+    }
+
     def __init__(
         self,
         enabled: bool,
@@ -35,120 +87,147 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
         self.vlm_options = vlm_options

-        if self.enabled:
-            from transformers import AutoProcessor
-            from vllm import LLM, SamplingParams
+        self.llm = None
+        self.sampling_params = None
+        self.processor = None  # used for CHAT templating in formulate_prompt()
+        self.device = "cpu"
+        self.max_new_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature

-            self.device = decide_device(
-                accelerator_options.device,
-                supported_devices=vlm_options.supported_devices,
+        if not self.enabled:
+            return
+
+        from transformers import AutoProcessor
+        from vllm import LLM, SamplingParams
+
+        # Device selection
+        self.device = decide_device(
+            accelerator_options.device, supported_devices=vlm_options.supported_devices
+        )
+        _log.debug(f"Available device for VLM: {self.device}")
+
+        # Resolve artifacts path / cache folder
+        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        if artifacts_path is None:
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
             )
-            _log.debug(f"Available device for VLM: {self.device}")
+        elif (artifacts_path / repo_cache_folder).exists():
+            artifacts_path = artifacts_path / repo_cache_folder

-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
+        # --------- Strict split & validation of extra_generation_config ---------
+        extra_cfg = self.vlm_options.extra_generation_config

-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
+        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}

-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            # Initialize VLLM LLM
-            llm_kwargs: Dict[str, Any] = {
-                "model": str(artifacts_path),
-                "limit_mm_per_prompt": {"image": 1},
-                "trust_remote_code": vlm_options.trust_remote_code,
-                "model_impl": "transformers",
-                "gpu_memory_utilization": 0.3,  # hardcoded for now, leaves room for ~3 different models.
-            }
-
-            # Add device-specific configurations
-
-            if self.device == "cpu":
-                llm_kwargs["device"] = "cpu"
-
-            # Add quantization if specified
-            if vlm_options.quantized:
-                if vlm_options.load_in_8bit:
-                    llm_kwargs["quantization"] = "bitsandbytes"
-
-            self.llm = LLM(**llm_kwargs)
-
-            # Initialize processor for prompt formatting
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                trust_remote_code=vlm_options.trust_remote_code,
+        unknown = sorted(
+            k
+            for k in extra_cfg.keys()
+            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
+        )
+        if unknown:
+            _log.warning(
+                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
             )

-            # Set up sampling parameters
-            self.sampling_params = SamplingParams(
-                temperature=self.temperature,
-                max_tokens=self.max_new_tokens,
-                stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
-                **vlm_options.extra_generation_config,
-            )
+        # --------- Construct LLM kwargs (engine/load-time) ---------
+        llm_kwargs: Dict[str, Any] = {
+            "model": str(artifacts_path),
+            "model_impl": "transformers",
+            "limit_mm_per_prompt": {"image": 1},
+            "revision": self.vlm_options.revision,
+            "trust_remote_code": self.vlm_options.trust_remote_code,
+            **load_cfg,
+        }
+
+        if self.device == "cpu":
+            llm_kwargs.setdefault("enforce_eager", True)
+        else:
+            llm_kwargs.setdefault(
+                "gpu_memory_utilization", 0.3
+            )  # room for other models
+
+        # Quantization (kept as-is; coarse)
+        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
+            llm_kwargs.setdefault("quantization", "bitsandbytes")
+
+        # Initialize vLLM LLM
+        self.llm = LLM(**llm_kwargs)
+
+        # Initialize processor for prompt templating (needed for CHAT style)
+        self.processor = AutoProcessor.from_pretrained(
+            artifacts_path,
+            trust_remote_code=self.vlm_options.trust_remote_code,
+            revision=self.vlm_options.revision,
+        )
+
+        # --------- SamplingParams (runtime) ---------
+        self.sampling_params = SamplingParams(
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            stop=(self.vlm_options.stop_strings or None),
+            **gen_cfg,
+        )

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
+        # If disabled, pass-through
+        if not self.enabled:
+            for page in page_batch:
+                yield page
+            return
+
         page_list = list(page_batch)
         if not page_list:
             return

-        valid_pages = []
-        invalid_pages = []
+        # Preserve original order
+        original_order = page_list[:]
+        # Separate valid/invalid
+        valid_pages: list[Page] = []
+        invalid_pages: list[Page] = []

         for page in page_list:
             assert page._backend is not None
-            if not page._backend.is_valid():
-                invalid_pages.append(page)
-            else:
+            if page._backend.is_valid():
                 valid_pages.append(page)
+            else:
+                invalid_pages.append(page)

-        # Process valid pages in batch
         if valid_pages:
             with TimeRecorder(conv_res, "vlm"):
-                # Prepare images and prompts for batch processing
-                images = []
-                user_prompts = []
-                pages_with_images = []
+                images: list[Image] = []
+                user_prompts: list[str] = []
+                pages_with_images: list[Page] = []

                 for page in valid_pages:
                     assert page.size is not None
                     hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                        scale=self.vlm_options.scale,
+                        max_size=self.vlm_options.max_size,
                     )
+                    if hi_res_image is None:
+                        continue

-                    # Only process pages with valid images
-                    if hi_res_image is not None:
-                        images.append(hi_res_image)
+                    images.append(hi_res_image)

-                        # Define prompt structure
-                        if callable(self.vlm_options.prompt):
-                            user_prompt = self.vlm_options.prompt(page.parsed_page)
-                        else:
-                            user_prompt = self.vlm_options.prompt
+                    # Define prompt structure
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)

-                        user_prompts.append(user_prompt)
-                        pages_with_images.append(page)
+                    user_prompts.append(user_prompt)
+                    pages_with_images.append(page)

-                # Use process_images for the actual inference
-                if images:  # Only if we have valid images
+                if images:
                     predictions = list(self.process_images(images, user_prompts))
-
-                    # Attach results to pages
                     for page, prediction in zip(pages_with_images, predictions):
                         page.predictions.vlm_response = prediction

-        # Yield all pages (valid and invalid)
-        for page in invalid_pages:
-            yield page
-        for page in valid_pages:
+        # Yield in original order
+        for page in original_order:
             yield page

     def process_images(
@@ -156,50 +235,33 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         image_batch: Iterable[Union[Image, np.ndarray]],
         prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata in a single batched inference call.
+        """Process images in a single batched vLLM inference call."""
+        import numpy as np
+        from PIL import Image as PILImage

-        Args:
-            image_batch: Iterable of PIL Images or numpy arrays
-            prompt: Either:
-                - str: Single prompt used for all images
-                - list[str]: List of prompts (one per image, must match image count)
-
-        Raises:
-            ValueError: If prompt list length doesn't match image count.
-        """
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
-
         for img in image_batch:
-            # Convert numpy array to PIL Image if needed
             if isinstance(img, np.ndarray):
-                if img.ndim == 3 and img.shape[2] in [3, 4]:
-                    from PIL import Image as PILImage
-
+                if img.ndim == 3 and img.shape[2] in (3, 4):
                     pil_img = PILImage.fromarray(img.astype(np.uint8))
                 elif img.ndim == 2:
-                    from PIL import Image as PILImage
-
                     pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                 else:
                     raise ValueError(f"Unsupported numpy array shape: {img.shape}")
             else:
                 pil_img = img
-
-            # Ensure image is in RGB mode (handles RGBA, L, etc.)
if pil_img.mode != "RGB": pil_img = pil_img.convert("RGB") - pil_images.append(pil_img) - if len(pil_images) == 0: + if not pil_images: return - # Handle prompt parameter + # Normalize prompts if isinstance(prompt, str): - # Single prompt for all images user_prompts = [prompt] * len(pil_images) elif isinstance(prompt, list): - # List of prompts (one per image) if len(prompt) != len(pil_images): raise ValueError( f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})" @@ -208,28 +270,31 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): else: raise ValueError(f"prompt must be str or list[str], got {type(prompt)}") - # Format prompts individually - prompts: list[str] = [ - self.formulate_prompt(user_prompt) for user_prompt in user_prompts + # Format prompts + prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts] + + # Build vLLM inputs + llm_inputs = [ + {"prompt": p, "multi_modal_data": {"image": im}} + for p, im in zip(prompts, pil_images) ] - # Prepare VLLM inputs - llm_inputs = [] - for prompt, image in zip(prompts, pil_images): - llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - + # Generate + assert self.llm is not None and self.sampling_params is not None start_time = time.time() outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params) # type: ignore generation_time = time.time() - start_time - # Logging tokens count for the first sample as a representative metric - if len(outputs) > 0: - num_tokens = len(outputs[0].outputs[0].token_ids) - _log.debug( - f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds." - ) + # Optional debug + if outputs: + try: + num_tokens = len(outputs[0].outputs[0].token_ids) + _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.") + except Exception: + pass + # Emit predictions for output in outputs: - # Apply decode_response to the output text - decoded_text = self.vlm_options.decode_response(output.outputs[0].text) + text = output.outputs[0].text if output.outputs else "" + decoded_text = self.vlm_options.decode_response(text) yield VlmPrediction(text=decoded_text, generation_time=generation_time)