use single HF VLM model class

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2025-06-02 13:25:51 +02:00
parent 8006683007
commit ea5719c39d
5 changed files with 55 additions and 222 deletions

==== changed file 1 of 5 ====

@@ -20,9 +20,13 @@ class ResponseFormat(str, Enum):

 class InferenceFramework(str, Enum):
     MLX = "mlx"
-    TRANSFORMERS = "transformers"  # TODO: how to flag this as outdated?
-    TRANSFORMERS_VISION2SEQ = "transformers-vision2seq"
-    TRANSFORMERS_CAUSALLM = "transformers-causallm"
+    TRANSFORMERS = "transformers"
+
+
+class TransformersModelType(str, Enum):
+    AUTOMODEL = "automodel"
+    AUTOMODEL_VISION2SEQ = "automodel-vision2seq"
+    AUTOMODEL_CAUSALLM = "automodel-causallm"


 class InlineVlmOptions(BaseVlmOptions):
@@ -35,6 +39,7 @@ class InlineVlmOptions(BaseVlmOptions):
     quantized: bool = False

     inference_framework: InferenceFramework
+    transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat

     supported_devices: List[AcceleratorDevice] = [
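With this split, a Transformers-backed spec now pairs InferenceFramework.TRANSFORMERS with a TransformersModelType value, as the presets in the next file do. A minimal sketch of a user-defined spec under the new options (not part of this commit; the repo id is a placeholder and the AcceleratorDevice import path is assumed from the surrounding modules):

# Illustrative only: a custom spec using the new TransformersModelType field.
from docling.datamodel.accelerator_options import AcceleratorDevice  # assumed path
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
    TransformersModelType,
)

MY_VLM_OPTIONS = InlineVlmOptions(
    repo_id="my-org/my-vision2seq-model",  # placeholder repo id
    prompt="Convert this page to markdown.",
    response_format=ResponseFormat.MARKDOWN,
    # one framework value covers all Transformers-based models ...
    inference_framework=InferenceFramework.TRANSFORMERS,
    # ... and the Auto* loader class is selected separately:
    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
    supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
    scale=2.0,
    temperature=0.0,
)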

==== changed file 2 of 5 ====

@@ -11,6 +11,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     InferenceFramework,
     InlineVlmOptions,
     ResponseFormat,
+    TransformersModelType,
 )

 _log = logging.getLogger(__name__)
@@ -31,7 +32,8 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     repo_id="ds4sd/SmolDocling-256M-preview",
     prompt="Convert this page to docling.",
     response_format=ResponseFormat.DOCTAGS,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
@@ -46,7 +48,8 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
     prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
@@ -71,7 +74,8 @@ PIXTRAL_12B_TRANSFORMERS = InlineVlmOptions(
     repo_id="mistral-community/pixtral-12b",
     prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
     scale=2.0,
     temperature=0.0,
@@ -93,7 +97,8 @@ PHI4_TRANSFORMERS = InlineVlmOptions(
     prompt="Convert this page to MarkDown. Do not miss any text and only output the bare markdown",
     trust_remote_code=True,
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_CAUSALLM,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_CAUSALLM,
     supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
     scale=2.0,
     temperature=0.0,

==== changed file 3 of 5 ====

@@ -3,14 +3,17 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersModelType,
+)
 from docling.models.base_model import BasePageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
@@ -21,9 +24,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)


-class HuggingFaceVlmModel_AutoModelForCausalLM(
-    BasePageModel, HuggingFaceModelDownloadMixin
-):
+class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -37,8 +38,10 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(

         if self.enabled:
             import torch
-            from transformers import (  # type: ignore
+            from transformers import (
+                AutoModel,
                 AutoModelForCausalLM,
+                AutoModelForVision2Seq,
                 AutoProcessor,
                 BitsAndBytesConfig,
                 GenerationConfig,
@@ -77,15 +80,26 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
                     llm_int8_threshold=vlm_options.llm_int8_threshold,
                 )

+            model_cls: Any = AutoModel
+            if (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_CAUSALLM
+            ):
+                model_cls = AutoModelForCausalLM
+            elif (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_VISION2SEQ
+            ):
+                model_cls = AutoModelForVision2Seq
+
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=vlm_options.trust_remote_code,
             )
-            self.vlm_model = AutoModelForCausalLM.from_pretrained(
+            self.vlm_model = model_cls.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
                 torch_dtype="auto",
-                quantization_config=self.param_quantization_config,
                 _attn_implementation=(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
@@ -109,51 +123,46 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None

-                    hi_res_image = page.get_image(scale=2)  # self.vlm_options.scale)
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                     # Define prompt structure
                     prompt = self.formulate_prompt()
-                    print(f"prompt: '{prompt}', size: {im_width}, {im_height}")

                     inputs = self.processor(
-                        text=prompt, images=hi_res_image, return_tensors="pt"
+                        text=prompt, images=[hi_res_image], return_tensors="pt"
                     ).to(self.device)

-                    # Generate response
                     start_time = time.time()
-                    generate_ids = self.vlm_model.generate(
+                    # Call model to generate:
+                    generated_ids = self.vlm_model.generate(
                         **inputs,
                         max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,  # Enables KV caching which can improve performance
+                        use_cache=self.use_cache,
                         temperature=self.temperature,
                         generation_config=self.generation_config,
                         **self.vlm_options.extra_generation_config,
                     )
-                    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
-                    num_tokens = len(generate_ids[0])
                     generation_time = time.time() - start_time

-                    response = self.processor.batch_decode(
-                        generate_ids,
-                        skip_special_tokens=True,
-                        clean_up_tokenization_spaces=False,
+                    generated_texts = self.processor.batch_decode(
+                        generated_ids[:, inputs["input_ids"].shape[1] :],
+                        skip_special_tokens=False,
                     )[0]
+                    num_tokens = len(generated_ids[0])

                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )

                     page.predictions.vlm_response = VlmPrediction(
-                        text=response, generation_time=generation_time
+                        text=generated_texts,
+                        generation_time=generation_time,
                     )

             yield page

     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""
         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             _log.debug("Using specialized prompt for Phi-4")
             # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
@@ -167,7 +176,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
             return prompt

-        _log.debug("Using default prompt for CasualLM using apply_chat_template")
         messages = [
             {
                 "role": "user",

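With the two wrappers merged, a single class loads either kind of checkpoint, and the Auto* loader follows from vlm_options.transformers_model_type. A minimal sketch of constructing it directly, mirroring how VlmPipeline wires it up in the last file below (not part of this commit; the vlm_model_specs module path is an assumption, and instantiation downloads the checkpoint when artifacts_path is None):

# Illustrative only: direct construction of the unified model wrapper.
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.vlm_model_specs import GRANITE_VISION_TRANSFORMERS  # assumed module path
from docling.models.vlm_models_inline.hf_transformers_model import (
    HuggingFaceTransformersVlmModel,
)

vlm_model = HuggingFaceTransformersVlmModel(
    enabled=True,
    artifacts_path=None,  # None makes __init__ download the model, as shown above
    accelerator_options=AcceleratorOptions(),
    vlm_options=GRANITE_VISION_TRANSFORMERS,  # AUTOMODEL_VISION2SEQ -> AutoModelForVision2Seq
)
# The instance is a page model: calling it with a ConversionResult and an
# iterable of pages yields pages with page.predictions.vlm_response populated.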
==== changed file 4 of 5 ====

@@ -1,166 +0,0 @@
-import logging
-import time
-from collections.abc import Iterable
-from pathlib import Path
-from typing import Optional
-
-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
-_log = logging.getLogger(__name__)
-
-
-class HuggingFaceVlmModel_AutoModelForVision2Seq(
-    BasePageModel, HuggingFaceModelDownloadMixin
-):
-    def __init__(
-        self,
-        enabled: bool,
-        artifacts_path: Optional[Path],
-        accelerator_options: AcceleratorOptions,
-        vlm_options: InlineVlmOptions,
-    ):
-        self.enabled = enabled
-        self.vlm_options = vlm_options
-
-        if self.enabled:
-            import torch
-            from transformers import (  # type: ignore
-                AutoModelForVision2Seq,
-                AutoProcessor,
-                BitsAndBytesConfig,
-            )
-
-            self.device = decide_device(
-                accelerator_options.device,
-                supported_devices=vlm_options.supported_devices,
-            )
-            _log.debug(f"Available device for VLM: {self.device}")
-
-            self.use_cache = vlm_options.use_kv_cache
-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
-
-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-
-            # PARAMETERS:
-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            self.param_quantization_config: Optional[BitsAndBytesConfig] = None
-            if vlm_options.quantized:
-                self.param_quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=vlm_options.load_in_8bit,
-                    llm_int8_threshold=vlm_options.llm_int8_threshold,
-                )
-
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                trust_remote_code=vlm_options.trust_remote_code,
-            )
-            self.vlm_model = AutoModelForVision2Seq.from_pretrained(
-                artifacts_path,
-                device_map=self.device,
-                # torch_dtype=torch.bfloat16,
-                _attn_implementation=(
-                    "flash_attention_2"
-                    if self.device.startswith("cuda")
-                    and accelerator_options.cuda_use_flash_attention2
-                    else "eager"
-                ),
-                trust_remote_code=vlm_options.trust_remote_code,
-            )
-
-    def __call__(
-        self, conv_res: ConversionResult, page_batch: Iterable[Page]
-    ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-
-                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    """
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-                    """
-
-                    # Define prompt structure
-                    prompt = self.formulate_prompt()
-
-                    inputs = self.processor(
-                        text=prompt, images=[hi_res_image], return_tensors="pt"
-                    )
-                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                    start_time = time.time()
-                    # Call model to generate:
-                    generated_ids = self.vlm_model.generate(
-                        **inputs,
-                        max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,
-                        temperature=self.temperature,
-                    )
-                    generation_time = time.time() - start_time
-
-                    generated_texts = self.processor.batch_decode(
-                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
-                    )[0]
-                    num_tokens = len(generated_ids[0])
-                    page_tags = generated_texts
-
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=page_tags,
-                        generation_time=generation_time,
-                    )
-
-            yield page
-
-    def formulate_prompt(self) -> str:
-        """Formulate a prompt for the VLM."""
-
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "This is a page from a document.",
-                    },
-                    {"type": "image"},
-                    {"type": "text", "text": self.vlm_options.prompt},
-                ],
-            }
-        ]
-        prompt = self.processor.apply_chat_template(
-            messages, add_generation_prompt=False
-        )
-        return prompt

==== changed file 5 of 5 ====

@@ -37,11 +37,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
 )
 from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel
-from docling.models.vlm_models_inline.hf_transformers_causallm_model import (
-    HuggingFaceVlmModel_AutoModelForCausalLM,
-)
-from docling.models.vlm_models_inline.hf_transformers_vision2seq_model import (
-    HuggingFaceVlmModel_AutoModelForVision2Seq,
+from docling.models.vlm_models_inline.hf_transformers_model import (
+    HuggingFaceTransformersVlmModel,
 )
 from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
@@ -97,25 +94,9 @@ class VlmPipeline(PaginatedPipeline):
                         vlm_options=vlm_options,
                     ),
                 ]
-            elif (
-                vlm_options.inference_framework
-                == InferenceFramework.TRANSFORMERS_VISION2SEQ
-                or vlm_options.inference_framework == InferenceFramework.TRANSFORMERS
-            ):
+            elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
                 self.build_pipe = [
-                    HuggingFaceVlmModel_AutoModelForVision2Seq(
-                        enabled=True,  # must be always enabled for this pipeline to make sense.
-                        artifacts_path=artifacts_path,
-                        accelerator_options=pipeline_options.accelerator_options,
-                        vlm_options=vlm_options,
-                    ),
-                ]
-            elif (
-                vlm_options.inference_framework
-                == InferenceFramework.TRANSFORMERS_CAUSALLM
-            ):
-                self.build_pipe = [
-                    HuggingFaceVlmModel_AutoModelForCausalLM(
+                    HuggingFaceTransformersVlmModel(
                         enabled=True,  # must be always enabled for this pipeline to make sense.
                         artifacts_path=artifacts_path,
                         accelerator_options=pipeline_options.accelerator_options,