Mirror of https://github.com/DS4SD/docling.git
Synced 2025-12-08 12:48:28 +00:00
fix: Update Transformers & VLLM inference code, CLI and VLM specs (#2322)
* Update VLLM inference code, CLI and VLM specs
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Fix generation and decoder args for HF model
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Fix vllm device args
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Cleanup
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Bugfixes
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
---------
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
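A minimal usage sketch of what this change enables from Python (not part of the commit; it assumes the usual docling VLM-pipeline entry points, and the input path is illustrative):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import VlmPipelineOptions
    from docling.datamodel.vlm_model_specs import GRANITEDOCLING_VLLM
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    # Route PDF conversion through the VLM pipeline, backed by vLLM inference.
    pipeline_options = VlmPipelineOptions(vlm_options=GRANITEDOCLING_VLLM)
    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
            )
        }
    )
    result = converter.convert("report.pdf")
    print(result.document.export_to_markdown())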
@@ -66,6 +66,7 @@ from docling.datamodel.vlm_model_specs import (
     GRANITE_VISION_TRANSFORMERS,
     GRANITEDOCLING_MLX,
     GRANITEDOCLING_TRANSFORMERS,
+    GRANITEDOCLING_VLLM,
     SMOLDOCLING_MLX,
     SMOLDOCLING_TRANSFORMERS,
     SMOLDOCLING_VLLM,
@@ -686,6 +687,7 @@ def convert( # noqa: C901
                         "To run SmolDocling faster, please install mlx-vlm:\n"
                         "pip install mlx-vlm"
                     )
+
         elif vlm_model == VlmModelType.GRANITEDOCLING:
             pipeline_options.vlm_options = GRANITEDOCLING_TRANSFORMERS
             if sys.platform == "darwin":
@@ -701,6 +703,9 @@ def convert( # noqa: C901
         elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
             pipeline_options.vlm_options = SMOLDOCLING_VLLM

+        elif vlm_model == VlmModelType.GRANITEDOCLING_VLLM:
+            pipeline_options.vlm_options = GRANITEDOCLING_VLLM
+
         pdf_format_option = PdfFormatOption(
             pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
         )
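On the command line, the two hunks above wire the same spec into the existing model selector, so (assuming the current flag names) it should be selectable as docling --pipeline vlm --vlm-model granite_docling_vllm <file>.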
@@ -53,6 +53,7 @@ class InlineVlmOptions(BaseVlmOptions):
     kind: Literal["inline_model_options"] = "inline_model_options"

     repo_id: str
+    revision: str = "main"
     trust_remote_code: bool = False
     load_in_8bit: bool = True
     llm_int8_threshold: float = 6.0
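The new revision field defaults to "main" and, as the later hunks show, is forwarded to both model download and the from_pretrained-style loading calls. A pinning sketch (not part of the commit):

    from docling.datamodel.vlm_model_specs import GRANITEDOCLING_TRANSFORMERS

    # Copy an existing spec and pin it to a fixed Hugging Face snapshot.
    pinned = GRANITEDOCLING_TRANSFORMERS.model_copy()
    pinned.revision = "main"  # replace with a tag or commit SHA to pin a snapshot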
@@ -29,12 +29,20 @@ GRANITEDOCLING_TRANSFORMERS = InlineVlmOptions(
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
     ],
+    extra_generation_config=dict(skip_special_tokens=False),
     scale=2.0,
     temperature=0.0,
     max_new_tokens=8192,
     stop_strings=["</doctag>", "<|end_of_text|>"],
 )

+GRANITEDOCLING_VLLM = GRANITEDOCLING_TRANSFORMERS.model_copy()
+GRANITEDOCLING_VLLM.inference_framework = InferenceFramework.VLLM
+GRANITEDOCLING_VLLM.revision = (
+    "untied"  # change back to "main" with next vllm release after 0.10.2
+)
+
+
 GRANITEDOCLING_MLX = InlineVlmOptions(
     repo_id="ibm-granite/granite-docling-258M-mlx",
     prompt="Convert this page to docling.",
@@ -302,3 +310,4 @@ class VlmModelType(str, Enum):
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
     GOT_OCR_2 = "got_ocr_2"
     GRANITEDOCLING = "granite_docling"
+    GRANITEDOCLING_VLLM = "granite_docling_vllm"
@@ -88,7 +88,8 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel):

         if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
             return user_prompt
+        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            return ""
         elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             _log.debug("Using specialized prompt for Phi-4")
             # Note: This might need adjustment for VLLM vs transformers
@@ -34,7 +34,12 @@ class HuggingFaceModelDownloadMixin:
         local_dir: Optional[Path] = None,
         force: bool = False,
         progress: bool = False,
+        revision: Optional[str] = None,
     ) -> Path:
         return download_hf_model(
-            repo_id=repo_id, local_dir=local_dir, force=force, progress=progress
+            repo_id=repo_id,
+            local_dir=local_dir,
+            force=force,
+            progress=progress,
+            revision=revision,
         )
@@ -75,7 +75,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")

         if artifacts_path is None:
-            artifacts_path = self.download_models(self.vlm_options.repo_id)
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
+            )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder

@@ -106,6 +108,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         self.processor = AutoProcessor.from_pretrained(
             artifacts_path,
             trust_remote_code=vlm_options.trust_remote_code,
+            revision=vlm_options.revision,
         )
         self.processor.tokenizer.padding_side = "left"

@@ -120,11 +123,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 else "sdpa"
             ),
             trust_remote_code=vlm_options.trust_remote_code,
+            revision=vlm_options.revision,
         )
         self.vlm_model = torch.compile(self.vlm_model)  # type: ignore

         # Load generation config
-        self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
+        self.generation_config = GenerationConfig.from_pretrained(
+            artifacts_path, revision=vlm_options.revision
+        )

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -196,7 +202,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         import torch
         from PIL import Image as PILImage

-        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
         for img in image_batch:
             if isinstance(img, np.ndarray):
@@ -258,13 +264,30 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             ]
         )

+        # -- Filter out decoder-specific keys from extra_generation_config
+        decoder_keys = {
+            "skip_special_tokens",
+            "clean_up_tokenization_spaces",
+            "spaces_between_special_tokens",
+        }
+        generation_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k not in decoder_keys
+        }
+        decoder_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k in decoder_keys
+        }
+
         # -- Generate (Image-Text-to-Text class expects these inputs from processor)
         gen_kwargs = {
             **inputs,
             "max_new_tokens": self.max_new_tokens,
             "use_cache": self.use_cache,
             "generation_config": self.generation_config,
-            **self.vlm_options.extra_generation_config,
+            **generation_config,
         }
         if self.temperature > 0:
             gen_kwargs["do_sample"] = True
@@ -293,7 +316,8 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         )

         decoded_texts: list[str] = decode_fn(
-            trimmed_sequences, skip_special_tokens=False
+            trimmed_sequences,
+            **decoder_config,
         )

         # -- Clip off pad tokens from decoded texts
@@ -60,6 +60,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         if artifacts_path is None:
             artifacts_path = self.download_models(
                 self.vlm_options.repo_id,
+                revision=self.vlm_options.revision,
             )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder
@@ -7,9 +7,7 @@ from typing import Any, Dict, Optional, Union
 import numpy as np
 from PIL.Image import Image

-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
+from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
@@ -17,9 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
+from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder

@@ -27,6 +23,62 @@ _log = logging.getLogger(__name__)


 class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    """
+    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
+    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
+    """
+
+    # --------- Allowlist of vLLM args ---------
+    # SamplingParams (runtime generation controls)
+    _VLLM_SAMPLING_KEYS = {
+        # Core
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "top_k",
+        # Penalties
+        "presence_penalty",
+        "frequency_penalty",
+        "repetition_penalty",
+        # Stops / outputs
+        "stop",
+        "stop_token_ids",
+        "skip_special_tokens",
+        "spaces_between_special_tokens",
+        # Search / length
+        "n",
+        "best_of",
+        "length_penalty",
+        "early_stopping",
+        # Misc
+        "logprobs",
+        "prompt_logprobs",
+        "min_p",
+        "seed",
+    }
+
+    # LLM(...) / EngineArgs (engine/load-time controls)
+    _VLLM_ENGINE_KEYS = {
+        # Model/tokenizer/impl
+        "tokenizer",
+        "tokenizer_mode",
+        "download_dir",
+        # Parallelism / memory / lengths
+        "tensor_parallel_size",
+        "pipeline_parallel_size",
+        "gpu_memory_utilization",
+        "max_model_len",
+        "max_num_batched_tokens",
+        "kv_cache_dtype",
+        "dtype",
+        # Quantization (coarse switch)
+        "quantization",
+        # Multimodal limits
+        "limit_mm_per_prompt",
+        # Execution toggles
+        "enforce_eager",
+    }
+
     def __init__(
         self,
         enabled: bool,
@@ -35,120 +87,147 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled

         self.vlm_options = vlm_options

-        if self.enabled:
-            from transformers import AutoProcessor
-            from vllm import LLM, SamplingParams
-
-            self.device = decide_device(
-                accelerator_options.device,
-                supported_devices=vlm_options.supported_devices,
+        self.llm = None
+        self.sampling_params = None
+        self.processor = None  # used for CHAT templating in formulate_prompt()
+        self.device = "cpu"
+        self.max_new_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature
+
+        if not self.enabled:
+            return
+
+        from transformers import AutoProcessor
+        from vllm import LLM, SamplingParams
+
+        # Device selection
+        self.device = decide_device(
+            accelerator_options.device, supported_devices=vlm_options.supported_devices
+        )
+        _log.debug(f"Available device for VLM: {self.device}")
+
+        # Resolve artifacts path / cache folder
+        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        if artifacts_path is None:
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
             )
-            _log.debug(f"Available device for VLM: {self.device}")
-
-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
+        elif (artifacts_path / repo_cache_folder).exists():
+            artifacts_path = artifacts_path / repo_cache_folder
+
+        # --------- Strict split & validation of extra_generation_config ---------
+        extra_cfg = self.vlm_options.extra_generation_config

-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
+        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}

-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            # Initialize VLLM LLM
-            llm_kwargs: Dict[str, Any] = {
-                "model": str(artifacts_path),
-                "limit_mm_per_prompt": {"image": 1},
-                "trust_remote_code": vlm_options.trust_remote_code,
-                "model_impl": "transformers",
-                "gpu_memory_utilization": 0.3,  # hardcoded for now, leaves room for ~3 different models.
-            }
-
-            # Add device-specific configurations
-            if self.device == "cpu":
-                llm_kwargs["device"] = "cpu"
-
-            # Add quantization if specified
-            if vlm_options.quantized:
-                if vlm_options.load_in_8bit:
-                    llm_kwargs["quantization"] = "bitsandbytes"
-
-            self.llm = LLM(**llm_kwargs)
-
-            # Initialize processor for prompt formatting
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                trust_remote_code=vlm_options.trust_remote_code,
+        unknown = sorted(
+            k
+            for k in extra_cfg.keys()
+            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
+        )
+        if unknown:
+            _log.warning(
+                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
             )

-            # Set up sampling parameters
-            self.sampling_params = SamplingParams(
-                temperature=self.temperature,
-                max_tokens=self.max_new_tokens,
-                stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
-                **vlm_options.extra_generation_config,
-            )
+        # --------- Construct LLM kwargs (engine/load-time) ---------
+        llm_kwargs: Dict[str, Any] = {
+            "model": str(artifacts_path),
+            "model_impl": "transformers",
+            "limit_mm_per_prompt": {"image": 1},
+            "revision": self.vlm_options.revision,
+            "trust_remote_code": self.vlm_options.trust_remote_code,
+            **load_cfg,
+        }
+
+        if self.device == "cpu":
+            llm_kwargs.setdefault("enforce_eager", True)
+        else:
+            llm_kwargs.setdefault(
+                "gpu_memory_utilization", 0.3
+            )  # room for other models
+
+        # Quantization (kept as-is; coarse)
+        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
+            llm_kwargs.setdefault("quantization", "bitsandbytes")
+
+        # Initialize vLLM LLM
+        self.llm = LLM(**llm_kwargs)
+
+        # Initialize processor for prompt templating (needed for CHAT style)
+        self.processor = AutoProcessor.from_pretrained(
+            artifacts_path,
+            trust_remote_code=self.vlm_options.trust_remote_code,
+            revision=self.vlm_options.revision,
+        )
+
+        # --------- SamplingParams (runtime) ---------
+        self.sampling_params = SamplingParams(
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            stop=(self.vlm_options.stop_strings or None),
+            **gen_cfg,
+        )

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
+        # If disabled, pass-through
+        if not self.enabled:
+            for page in page_batch:
+                yield page
+            return
+
         page_list = list(page_batch)
         if not page_list:
             return

-        valid_pages = []
-        invalid_pages = []
+        # Preserve original order
+        original_order = page_list[:]
+
+        # Separate valid/invalid
+        valid_pages: list[Page] = []
+        invalid_pages: list[Page] = []
         for page in page_list:
             assert page._backend is not None
-            if not page._backend.is_valid():
-                invalid_pages.append(page)
-            else:
+            if page._backend.is_valid():
                 valid_pages.append(page)
+            else:
+                invalid_pages.append(page)

-        # Process valid pages in batch
         if valid_pages:
             with TimeRecorder(conv_res, "vlm"):
-                # Prepare images and prompts for batch processing
-                images = []
-                user_prompts = []
-                pages_with_images = []
+                images: list[Image] = []
+                user_prompts: list[str] = []
+                pages_with_images: list[Page] = []

                 for page in valid_pages:
                     assert page.size is not None
                     hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                        scale=self.vlm_options.scale,
+                        max_size=self.vlm_options.max_size,
                     )
+                    if hi_res_image is None:
+                        continue
+
+                    images.append(hi_res_image)

-                    # Only process pages with valid images
-                    if hi_res_image is not None:
-                        images.append(hi_res_image)
-
-                        # Define prompt structure
-                        if callable(self.vlm_options.prompt):
-                            user_prompt = self.vlm_options.prompt(page.parsed_page)
-                        else:
-                            user_prompt = self.vlm_options.prompt
-
-                        user_prompts.append(user_prompt)
-                        pages_with_images.append(page)
+                    # Define prompt structure
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
+
+                    user_prompts.append(user_prompt)
+                    pages_with_images.append(page)

-                # Use process_images for the actual inference
-                if images:  # Only if we have valid images
+                if images:
                     predictions = list(self.process_images(images, user_prompts))

-                    # Attach results to pages
                     for page, prediction in zip(pages_with_images, predictions):
                         page.predictions.vlm_response = prediction

-        # Yield all pages (valid and invalid)
-        for page in invalid_pages:
-            yield page
-        for page in valid_pages:
+        # Yield in original order
+        for page in original_order:
             yield page

     def process_images(
@@ -156,50 +235,33 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         image_batch: Iterable[Union[Image, np.ndarray]],
         prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata in a single batched inference call.
-
-        Args:
-            image_batch: Iterable of PIL Images or numpy arrays
-            prompt: Either:
-                - str: Single prompt used for all images
-                - list[str]: List of prompts (one per image, must match image count)
-
-        Raises:
-            ValueError: If prompt list length doesn't match image count.
-        """
+        """Process images in a single batched vLLM inference call."""
+        import numpy as np
+        from PIL import Image as PILImage
+
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []

         for img in image_batch:
-            # Convert numpy array to PIL Image if needed
             if isinstance(img, np.ndarray):
-                if img.ndim == 3 and img.shape[2] in [3, 4]:
-                    from PIL import Image as PILImage
-
+                if img.ndim == 3 and img.shape[2] in (3, 4):
                     pil_img = PILImage.fromarray(img.astype(np.uint8))
                 elif img.ndim == 2:
-                    from PIL import Image as PILImage
-
                     pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                 else:
                     raise ValueError(f"Unsupported numpy array shape: {img.shape}")
             else:
                 pil_img = img

-            # Ensure image is in RGB mode (handles RGBA, L, etc.)
             if pil_img.mode != "RGB":
                 pil_img = pil_img.convert("RGB")

             pil_images.append(pil_img)

-        if len(pil_images) == 0:
+        if not pil_images:
            return

-        # Handle prompt parameter
+        # Normalize prompts
         if isinstance(prompt, str):
-            # Single prompt for all images
             user_prompts = [prompt] * len(pil_images)
         elif isinstance(prompt, list):
-            # List of prompts (one per image)
             if len(prompt) != len(pil_images):
                 raise ValueError(
                     f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
@@ -208,28 +270,31 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         else:
             raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")

-        # Format prompts individually
-        prompts: list[str] = [
-            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
+        # Format prompts
+        prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts]
+
+        # Build vLLM inputs
+        llm_inputs = [
+            {"prompt": p, "multi_modal_data": {"image": im}}
+            for p, im in zip(prompts, pil_images)
         ]

-        # Prepare VLLM inputs
-        llm_inputs = []
-        for prompt, image in zip(prompts, pil_images):
-            llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
+        # Generate
+        assert self.llm is not None and self.sampling_params is not None

         start_time = time.time()
         outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
         generation_time = time.time() - start_time

-        # Logging tokens count for the first sample as a representative metric
-        if len(outputs) > 0:
-            num_tokens = len(outputs[0].outputs[0].token_ids)
-            _log.debug(
-                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-            )
+        # Optional debug
+        if outputs:
+            try:
+                num_tokens = len(outputs[0].outputs[0].token_ids)
+                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+            except Exception:
+                pass

+        # Emit predictions
         for output in outputs:
-            # Apply decode_response to the output text
-            decoded_text = self.vlm_options.decode_response(output.outputs[0].text)
+            text = output.outputs[0].text if output.outputs else ""
+            decoded_text = self.vlm_options.decode_response(text)
             yield VlmPrediction(text=decoded_text, generation_time=generation_time)
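A configuration sketch for the allowlist-based split introduced above in the vLLM model (keys come from the allowlists in the diff; the values and the mutation style are illustrative, not from the commit):

    from docling.datamodel.vlm_model_specs import GRANITEDOCLING_VLLM

    opts = GRANITEDOCLING_VLLM.model_copy()
    opts.extra_generation_config = {
        "max_model_len": 8192,          # engine key -> forwarded to LLM(**llm_kwargs)
        "gpu_memory_utilization": 0.5,  # engine key
        "repetition_penalty": 1.05,     # sampling key -> forwarded to SamplingParams(...)
        "skip_special_tokens": False,   # sampling key
        # Keys outside the two allowlists are dropped with a warning at model init.
    }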