use single HF VLM model class

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2025-06-02 13:25:51 +02:00
parent 8006683007
commit ea5719c39d
5 changed files with 55 additions and 222 deletions

docling/datamodel/pipeline_options_vlm_model.py

@@ -20,9 +20,13 @@ class ResponseFormat(str, Enum):
 class InferenceFramework(str, Enum):
     MLX = "mlx"
-    TRANSFORMERS = "transformers"  # TODO: how to flag this as outdated?
-    TRANSFORMERS_VISION2SEQ = "transformers-vision2seq"
-    TRANSFORMERS_CAUSALLM = "transformers-causallm"
+    TRANSFORMERS = "transformers"
+
+
+class TransformersModelType(str, Enum):
+    AUTOMODEL = "automodel"
+    AUTOMODEL_VISION2SEQ = "automodel-vision2seq"
+    AUTOMODEL_CAUSALLM = "automodel-causallm"


 class InlineVlmOptions(BaseVlmOptions):
@@ -35,6 +39,7 @@ class InlineVlmOptions(BaseVlmOptions):
     quantized: bool = False

     inference_framework: InferenceFramework
+    transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat

     supported_devices: List[AcceleratorDevice] = [
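
Taken together, the two hunks above replace the per-framework enum values with a single TRANSFORMERS framework plus an orthogonal transformers_model_type field. A minimal sketch of how an InlineVlmOptions is declared after this change (values borrowed from the SmolDocling preset below; the snippet is illustrative and not part of the commit):

    from docling.datamodel.pipeline_options_vlm_model import (
        InferenceFramework,
        InlineVlmOptions,
        ResponseFormat,
        TransformersModelType,
    )

    options = InlineVlmOptions(
        repo_id="ds4sd/SmolDocling-256M-preview",
        prompt="Convert this page to docling.",
        response_format=ResponseFormat.DOCTAGS,
        inference_framework=InferenceFramework.TRANSFORMERS,
        # Defaults to TransformersModelType.AUTOMODEL when omitted.
        transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
    )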

docling/datamodel/vlm_model_specs.py

@@ -11,6 +11,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     InferenceFramework,
     InlineVlmOptions,
     ResponseFormat,
+    TransformersModelType,
 )

 _log = logging.getLogger(__name__)
@@ -31,7 +32,8 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     repo_id="ds4sd/SmolDocling-256M-preview",
     prompt="Convert this page to docling.",
     response_format=ResponseFormat.DOCTAGS,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
@@ -46,7 +48,8 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
     prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
@@ -71,7 +74,8 @@ PIXTRAL_12B_TRANSFORMERS = InlineVlmOptions(
     repo_id="mistral-community/pixtral-12b",
     prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_VISION2SEQ,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
     supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
     scale=2.0,
     temperature=0.0,
@@ -93,7 +97,8 @@ PHI4_TRANSFORMERS = InlineVlmOptions(
     prompt="Convert this page to MarkDown. Do not miss any text and only output the bare markdown",
     trust_remote_code=True,
     response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_CAUSALLM,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_CAUSALLM,
     supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
     scale=2.0,
     temperature=0.0,

docling/models/vlm_models_inline/hf_transformers_causallm_model.py → docling/models/vlm_models_inline/hf_transformers_model.py

@@ -3,14 +3,17 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersModelType,
+)
 from docling.models.base_model import BasePageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
@@ -21,9 +24,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)


-class HuggingFaceVlmModel_AutoModelForCausalLM(
-    BasePageModel, HuggingFaceModelDownloadMixin
-):
+class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -37,8 +38,10 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
         if self.enabled:
             import torch
-            from transformers import (  # type: ignore
+            from transformers import (
+                AutoModel,
                 AutoModelForCausalLM,
+                AutoModelForVision2Seq,
                 AutoProcessor,
                 BitsAndBytesConfig,
                 GenerationConfig,
@@ -77,15 +80,26 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
                     llm_int8_threshold=vlm_options.llm_int8_threshold,
                 )

+            model_cls: Any = AutoModel
+            if (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_CAUSALLM
+            ):
+                model_cls = AutoModelForCausalLM
+            elif (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_VISION2SEQ
+            ):
+                model_cls = AutoModelForVision2Seq
+
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=vlm_options.trust_remote_code,
             )
-            self.vlm_model = AutoModelForCausalLM.from_pretrained(
+            self.vlm_model = model_cls.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
                 torch_dtype="auto",
                 quantization_config=self.param_quantization_config,
                 _attn_implementation=(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
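
The hunk above maps the new TransformersModelType values onto Hugging Face Auto* classes. A standalone sketch of the same dispatch outside docling; the transformers class names are real, while the dictionary, the helper name, and repo_id are placeholders for illustration:

    from typing import Any

    from transformers import AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq

    _MODEL_CLASSES: dict[str, Any] = {
        "automodel": AutoModel,
        "automodel-causallm": AutoModelForCausalLM,
        "automodel-vision2seq": AutoModelForVision2Seq,
    }

    def load_checkpoint(repo_id: str, model_type: str = "automodel") -> Any:
        """Resolve the Auto* class from the configured model type and load the weights."""
        model_cls = _MODEL_CLASSES[model_type]
        return model_cls.from_pretrained(repo_id, device_map="auto", torch_dtype="auto")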
@@ -109,51 +123,46 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None

-                    hi_res_image = page.get_image(scale=2)  # self.vlm_options.scale)
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                     # Define prompt structure
                     prompt = self.formulate_prompt()
-                    print(f"prompt: '{prompt}', size: {im_width}, {im_height}")

                     inputs = self.processor(
-                        text=prompt, images=hi_res_image, return_tensors="pt"
+                        text=prompt, images=[hi_res_image], return_tensors="pt"
                     ).to(self.device)

-                    # Generate response
                     start_time = time.time()
-                    generate_ids = self.vlm_model.generate(
+                    # Call model to generate:
+                    generated_ids = self.vlm_model.generate(
                         **inputs,
                         max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,  # Enables KV caching which can improve performance
+                        use_cache=self.use_cache,
                         temperature=self.temperature,
                         generation_config=self.generation_config,
                         **self.vlm_options.extra_generation_config,
                     )
-                    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
-                    num_tokens = len(generate_ids[0])
                     generation_time = time.time() - start_time

-                    response = self.processor.batch_decode(
-                        generate_ids,
-                        skip_special_tokens=True,
-                        clean_up_tokenization_spaces=False,
+                    generated_texts = self.processor.batch_decode(
+                        generated_ids[:, inputs["input_ids"].shape[1] :],
+                        skip_special_tokens=False,
                     )[0]

+                    num_tokens = len(generated_ids[0])
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
                     page.predictions.vlm_response = VlmPrediction(
-                        text=response, generation_time=generation_time
+                        text=generated_texts,
+                        generation_time=generation_time,
                     )

                 yield page

     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""

         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             _log.debug("Using specialized prompt for Phi-4")
             # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
@@ -167,7 +176,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(
             return prompt

         _log.debug("Using default prompt for CasualLM using apply_chat_template")
-
         messages = [
             {
                 "role": "user",

docling/models/vlm_models_inline/hf_transformers_vision2seq_model.py (deleted)

@@ -1,166 +0,0 @@
-import logging
-import time
-from collections.abc import Iterable
-from pathlib import Path
-from typing import Optional
-
-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
-_log = logging.getLogger(__name__)
-
-
-class HuggingFaceVlmModel_AutoModelForVision2Seq(
-    BasePageModel, HuggingFaceModelDownloadMixin
-):
-    def __init__(
-        self,
-        enabled: bool,
-        artifacts_path: Optional[Path],
-        accelerator_options: AcceleratorOptions,
-        vlm_options: InlineVlmOptions,
-    ):
-        self.enabled = enabled
-        self.vlm_options = vlm_options
-
-        if self.enabled:
-            import torch
-            from transformers import (  # type: ignore
-                AutoModelForVision2Seq,
-                AutoProcessor,
-                BitsAndBytesConfig,
-            )
-
-            self.device = decide_device(
-                accelerator_options.device,
-                supported_devices=vlm_options.supported_devices,
-            )
-            _log.debug(f"Available device for VLM: {self.device}")
-
-            self.use_cache = vlm_options.use_kv_cache
-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
-
-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-
-            # PARAMETERS:
-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            self.param_quantization_config: Optional[BitsAndBytesConfig] = None
-            if vlm_options.quantized:
-                self.param_quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=vlm_options.load_in_8bit,
-                    llm_int8_threshold=vlm_options.llm_int8_threshold,
-                )
-
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                trust_remote_code=vlm_options.trust_remote_code,
-            )
-            self.vlm_model = AutoModelForVision2Seq.from_pretrained(
-                artifacts_path,
-                device_map=self.device,
-                # torch_dtype=torch.bfloat16,
-                _attn_implementation=(
-                    "flash_attention_2"
-                    if self.device.startswith("cuda")
-                    and accelerator_options.cuda_use_flash_attention2
-                    else "eager"
-                ),
-                trust_remote_code=vlm_options.trust_remote_code,
-            )
-
-    def __call__(
-        self, conv_res: ConversionResult, page_batch: Iterable[Page]
-    ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-
-                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    """
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-                    """
-
-                    # Define prompt structure
-                    prompt = self.formulate_prompt()
-
-                    inputs = self.processor(
-                        text=prompt, images=[hi_res_image], return_tensors="pt"
-                    )
-                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                    start_time = time.time()
-                    # Call model to generate:
-                    generated_ids = self.vlm_model.generate(
-                        **inputs,
-                        max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,
-                        temperature=self.temperature,
-                    )
-                    generation_time = time.time() - start_time
-                    generated_texts = self.processor.batch_decode(
-                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
-                    )[0]
-
-                    num_tokens = len(generated_ids[0])
-                    page_tags = generated_texts
-
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=page_tags,
-                        generation_time=generation_time,
-                    )
-
-                yield page
-
-    def formulate_prompt(self) -> str:
-        """Formulate a prompt for the VLM."""
-
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "This is a page from a document.",
-                    },
-                    {"type": "image"},
-                    {"type": "text", "text": self.vlm_options.prompt},
-                ],
-            }
-        ]
-        prompt = self.processor.apply_chat_template(
-            messages, add_generation_prompt=False
-        )
-        return prompt
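
The deleted class's default prompt construction survives in the consolidated model (see the apply_chat_template branch of formulate_prompt above). The same pattern in generic form, where processor is a placeholder for any Hugging Face processor that ships a chat template:

    def build_prompt(processor, user_prompt: str) -> str:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "This is a page from a document."},
                    {"type": "image"},
                    {"type": "text", "text": user_prompt},
                ],
            }
        ]
        # Renders the messages into the model-specific prompt string, including the
        # image placeholder token the processor expects.
        return processor.apply_chat_template(messages, add_generation_prompt=False)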

docling/pipeline/vlm_pipeline.py

@@ -37,11 +37,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
 )
 from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel
-from docling.models.vlm_models_inline.hf_transformers_causallm_model import (
-    HuggingFaceVlmModel_AutoModelForCausalLM,
-)
-from docling.models.vlm_models_inline.hf_transformers_vision2seq_model import (
-    HuggingFaceVlmModel_AutoModelForVision2Seq,
+from docling.models.vlm_models_inline.hf_transformers_model import (
+    HuggingFaceTransformersVlmModel,
 )
 from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
@@ -97,25 +94,9 @@ class VlmPipeline(PaginatedPipeline):
                         vlm_options=vlm_options,
                     ),
                 ]
-            elif (
-                vlm_options.inference_framework
-                == InferenceFramework.TRANSFORMERS_VISION2SEQ
-                or vlm_options.inference_framework == InferenceFramework.TRANSFORMERS
-            ):
+            elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
                 self.build_pipe = [
-                    HuggingFaceVlmModel_AutoModelForVision2Seq(
-                        enabled=True,  # must be always enabled for this pipeline to make sense.
-                        artifacts_path=artifacts_path,
-                        accelerator_options=pipeline_options.accelerator_options,
-                        vlm_options=vlm_options,
-                    ),
-                ]
-            elif (
-                vlm_options.inference_framework
-                == InferenceFramework.TRANSFORMERS_CAUSALLM
-            ):
-                self.build_pipe = [
-                    HuggingFaceVlmModel_AutoModelForCausalLM(
+                    HuggingFaceTransformersVlmModel(
                         enabled=True,  # must be always enabled for this pipeline to make sense.
                         artifacts_path=artifacts_path,
                         accelerator_options=pipeline_options.accelerator_options,
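
After this change, the pipeline branches only on InferenceFramework.TRANSFORMERS; the concrete Auto* class is taken from the preset's transformers_model_type. A hypothetical end-to-end sketch, assuming docling's usual converter and pipeline-options imports, which are not shown in this diff:

    from docling.datamodel import vlm_model_specs
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import VlmPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    pipeline_options = VlmPipelineOptions(
        # AUTOMODEL_VISION2SEQ preset; PHI4_TRANSFORMERS would exercise AUTOMODEL_CAUSALLM
        # through the same HuggingFaceTransformersVlmModel.
        vlm_options=vlm_model_specs.GRANITE_VISION_TRANSFORMERS,
    )
    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
            )
        }
    )
    result = converter.convert("page.pdf")  # any PDF path
    print(result.document.export_to_markdown())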