diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index d4510318..c1ec28aa 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -20,9 +20,13 @@ class ResponseFormat(str, Enum):
 
 class InferenceFramework(str, Enum):
     MLX = "mlx"
-    TRANSFORMERS = "transformers"  # TODO: how to flag this as outdated?
-    TRANSFORMERS_VISION2SEQ = "transformers-vision2seq"
-    TRANSFORMERS_CAUSALLM = "transformers-causallm"
+    TRANSFORMERS = "transformers"
+
+
+class TransformersModelType(str, Enum):
+    AUTOMODEL = "automodel"
+    AUTOMODEL_VISION2SEQ = "automodel-vision2seq"
+    AUTOMODEL_CAUSALLM = "automodel-causallm"
 
 
 class InlineVlmOptions(BaseVlmOptions):
@@ -35,6 +39,7 @@ class InlineVlmOptions(BaseVlmOptions):
     quantized: bool = False
 
     inference_framework: InferenceFramework
+    transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat
 
     supported_devices: List[AcceleratorDevice] = [
diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py
index 85ad1b40..5045c846 100644
--- a/docling/datamodel/vlm_model_specs.py
+++ b/docling/datamodel/vlm_model_specs.py
@@ -11,6 +11,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     InferenceFramework,
     InlineVlmOptions,
     ResponseFormat,
+    TransformersModelType,
 )
 
 _log = logging.getLogger(__name__)
@@ -31,7 +32,8 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     repo_id="ds4sd/SmolDocling-256M-preview",
     prompt="Convert this page to docling.",
     response_format=ResponseFormat.DOCTAGS,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
@@ -46,7 +48,8 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
     prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
@@ -71,7 +74,8 @@ PIXTRAL_12B_TRANSFORMERS = InlineVlmOptions(
     repo_id="mistral-community/pixtral-12b",
     prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
     scale=2.0,
     temperature=0.0,
@@ -93,7 +97,8 @@ PHI4_TRANSFORMERS = InlineVlmOptions(
     prompt="Convert this page to MarkDown. Do not miss any text and only output the bare markdown",
     trust_remote_code=True,
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_CAUSALLM,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_CAUSALLM,
     supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
     scale=2.0,
     temperature=0.0,
diff --git a/docling/models/vlm_models_inline/hf_transformers_causallm_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
similarity index 79%
rename from docling/models/vlm_models_inline/hf_transformers_causallm_model.py
rename to docling/models/vlm_models_inline/hf_transformers_model.py
index aef23b79..bc02fbc6 100644
--- a/docling/models/vlm_models_inline/hf_transformers_causallm_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -3,14 +3,17 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersModelType,
+)
 from docling.models.base_model import BasePageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -21,9 +24,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)
 
 
-class HuggingFaceVlmModel_AutoModelForCausalLM(
-    BasePageModel, HuggingFaceModelDownloadMixin
-):
+class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -37,8 +38,10 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
 
         if self.enabled:
             import torch
-            from transformers import (  # type: ignore
+            from transformers import (
+                AutoModel,
                 AutoModelForCausalLM,
+                AutoModelForVision2Seq,
                 AutoProcessor,
                 BitsAndBytesConfig,
                 GenerationConfig,
@@ -77,15 +80,26 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
                     llm_int8_threshold=vlm_options.llm_int8_threshold,
                 )
 
+            model_cls: Any = AutoModel
+            if (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_CAUSALLM
+            ):
+                model_cls = AutoModelForCausalLM
+            elif (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_VISION2SEQ
+            ):
+                model_cls = AutoModelForVision2Seq
+
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=vlm_options.trust_remote_code,
             )
-            self.vlm_model = AutoModelForCausalLM.from_pretrained(
+            self.vlm_model = model_cls.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
                 torch_dtype="auto",
-                quantization_config=self.param_quantization_config,
                 _attn_implementation=(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
@@ -109,51 +123,46 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None
 
-                    hi_res_image = page.get_image(scale=2)  # self.vlm_options.scale)
-
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
 
                     # Define prompt structure
                     prompt = self.formulate_prompt()
-                    print(f"prompt: '{prompt}', size: {im_width}, {im_height}")
 
                     inputs = self.processor(
-                        text=prompt, images=hi_res_image, return_tensors="pt"
+                        text=prompt, images=[hi_res_image], return_tensors="pt"
                     ).to(self.device)
 
-                    # Generate response
                     start_time = time.time()
-                    generate_ids = self.vlm_model.generate(
+                    # Call model to generate:
+                    generated_ids = self.vlm_model.generate(
                         **inputs,
                         max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,  # Enables KV caching which can improve performance
+                        use_cache=self.use_cache,
                         temperature=self.temperature,
                         generation_config=self.generation_config,
                         **self.vlm_options.extra_generation_config,
                     )
 
-                    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
-                    num_tokens = len(generate_ids[0])
                     generation_time = time.time() - start_time
-
-                    response = self.processor.batch_decode(
-                        generate_ids,
-                        skip_special_tokens=True,
-                        clean_up_tokenization_spaces=False,
+                    generated_texts = self.processor.batch_decode(
+                        generated_ids[:, inputs["input_ids"].shape[1] :],
+                        skip_special_tokens=False,
                     )[0]
 
+                    num_tokens = len(generated_ids[0])
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
                     page.predictions.vlm_response = VlmPrediction(
-                        text=response, generation_time=generation_time
+                        text=generated_texts,
+                        generation_time=generation_time,
                     )
 
                 yield page
 
     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""
+
         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             _log.debug("Using specialized prompt for Phi-4")
             # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
@@ -167,7 +176,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
 
             return prompt
 
-        _log.debug("Using default prompt for CasualLM using apply_chat_template")
         messages = [
             {
                 "role": "user",
diff --git a/docling/models/vlm_models_inline/hf_transformers_vision2seq_model.py b/docling/models/vlm_models_inline/hf_transformers_vision2seq_model.py
deleted file mode 100644
index f92b9292..00000000
--- a/docling/models/vlm_models_inline/hf_transformers_vision2seq_model.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import logging
-import time
-from collections.abc import Iterable
-from pathlib import Path
-from typing import Optional
-
-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
-_log = logging.getLogger(__name__)
-
-
-class HuggingFaceVlmModel_AutoModelForVision2Seq(
-    BasePageModel, HuggingFaceModelDownloadMixin
-):
-    def __init__(
-        self,
-        enabled: bool,
-        artifacts_path: Optional[Path],
-        accelerator_options: AcceleratorOptions,
-        vlm_options: InlineVlmOptions,
-    ):
-        self.enabled = enabled
-
-        self.vlm_options = vlm_options
-
-        if self.enabled:
-            import torch
-            from transformers import (  # type: ignore
-                AutoModelForVision2Seq,
-                AutoProcessor,
-                BitsAndBytesConfig,
-            )
-
-            self.device = decide_device(
-                accelerator_options.device,
-                supported_devices=vlm_options.supported_devices,
-            )
-            _log.debug(f"Available device for VLM: {self.device}")
-
-            self.use_cache = vlm_options.use_kv_cache
-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
-
-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-
-            # PARAMETERS:
-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            self.param_quantization_config: Optional[BitsAndBytesConfig] = None
-            if vlm_options.quantized:
-                self.param_quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=vlm_options.load_in_8bit,
-                    llm_int8_threshold=vlm_options.llm_int8_threshold,
-                )
-
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                trust_remote_code=vlm_options.trust_remote_code,
-            )
-            self.vlm_model = AutoModelForVision2Seq.from_pretrained(
-                artifacts_path,
-                device_map=self.device,
-                # torch_dtype=torch.bfloat16,
-                _attn_implementation=(
-                    "flash_attention_2"
-                    if self.device.startswith("cuda")
-                    and accelerator_options.cuda_use_flash_attention2
-                    else "eager"
-                ),
-                trust_remote_code=vlm_options.trust_remote_code,
-            )
-
-    def __call__(
-        self, conv_res: ConversionResult, page_batch: Iterable[Page]
-    ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-
-                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
-
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    """
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-                    """
-
-                    # Define prompt structure
-                    prompt = self.formulate_prompt()
-
-                    inputs = self.processor(
-                        text=prompt, images=[hi_res_image], return_tensors="pt"
-                    )
-                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                    start_time = time.time()
-                    # Call model to generate:
-                    generated_ids = self.vlm_model.generate(
-                        **inputs,
-                        max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,
-                        temperature=self.temperature,
-                    )
-
-                    generation_time = time.time() - start_time
-                    generated_texts = self.processor.batch_decode(
-                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
-                    )[0]
-
-                    num_tokens = len(generated_ids[0])
-                    page_tags = generated_texts
-
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=page_tags,
-                        generation_time=generation_time,
-                    )
-
-                yield page
-
-    def formulate_prompt(self) -> str:
-        """Formulate a prompt for the VLM."""
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "This is a page from a document.",
-                    },
-                    {"type": "image"},
-                    {"type": "text", "text": self.vlm_options.prompt},
-                ],
-            }
-        ]
-        prompt = self.processor.apply_chat_template(
-            messages, add_generation_prompt=False
-        )
-        return prompt
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index 8f21b9d4..2ecfe55a 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -37,11 +37,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
 )
 from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel
-from docling.models.vlm_models_inline.hf_transformers_causallm_model import (
-    HuggingFaceVlmModel_AutoModelForCausalLM,
-)
-from docling.models.vlm_models_inline.hf_transformers_vision2seq_model import (
-    HuggingFaceVlmModel_AutoModelForVision2Seq,
+from docling.models.vlm_models_inline.hf_transformers_model import (
+    HuggingFaceTransformersVlmModel,
 )
 from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
@@ -97,25 +94,9 @@ class VlmPipeline(PaginatedPipeline):
                     vlm_options=vlm_options,
                 ),
             ]
-            elif (
-                vlm_options.inference_framework
-                == InferenceFramework.TRANSFORMERS_VISION2SEQ
-                or vlm_options.inference_framework == InferenceFramework.TRANSFORMERS
-            ):
+            elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
                 self.build_pipe = [
-                    HuggingFaceVlmModel_AutoModelForVision2Seq(
-                        enabled=True,  # must be always enabled for this pipeline to make sense.
-                        artifacts_path=artifacts_path,
-                        accelerator_options=pipeline_options.accelerator_options,
-                        vlm_options=vlm_options,
-                    ),
-                ]
-            elif (
-                vlm_options.inference_framework
-                == InferenceFramework.TRANSFORMERS_CAUSALLM
-            ):
-                self.build_pipe = [
-                    HuggingFaceVlmModel_AutoModelForCausalLM(
+                    HuggingFaceTransformersVlmModel(
                         enabled=True,  # must be always enabled for this pipeline to make sense.
                         artifacts_path=artifacts_path,
                         accelerator_options=pipeline_options.accelerator_options,
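
Note for reviewers, not part of the patch: a minimal sketch of how a model spec opts into the consolidated backend after this change. It only uses classes touched by the diff; the spec name `MY_VISION2SEQ_OPTIONS` is hypothetical, and the `AcceleratorDevice` import path is assumed to match the one already used by `vlm_model_specs.py`.

```python
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
    TransformersModelType,
)

# Hypothetical spec: instead of choosing a framework per loader class
# (TRANSFORMERS_VISION2SEQ / TRANSFORMERS_CAUSALLM), every Hugging Face model
# now uses InferenceFramework.TRANSFORMERS and picks the Auto* loader via
# transformers_model_type; AUTOMODEL is the default fallback.
MY_VISION2SEQ_OPTIONS = InlineVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.TRANSFORMERS,
    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
    supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
    scale=2.0,
    temperature=0.0,
)
```

With this shape, `HuggingFaceTransformersVlmModel` resolves the loader class from `transformers_model_type` (AutoModelForVision2Seq here), so the pipeline only needs a single `TRANSFORMERS` branch.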