Mirror of https://github.com/DS4SD/docling.git
feat: batching support for VLMs in transformers backend, add initial VLLM backend (#2094)
* Prepare existing code for use with new multi-stage VLM pipeline
* Add multithreaded VLM pipeline
* Add VLM task interpreters
* Remove prints
* Fix KeyboardInterrupt behaviour
* Add VLLM backend support, optimize process_images
* Tweak defaults
* Implement proper batch inference for HuggingFaceTransformersVlmModel
* Small fixes
* Cleanup hf_transformers_model batching impl
* Adjust example instantiation of multi-stage VLM pipeline
* Add GoT OCR 2.0
* Factor out changes without multi-stage pipeline
* Reset defaults for generation
* Cleanup
* Add torch.compile, fix temperature setting in gen_kwargs
* Expose page_batch_size on CLI
* Add torch_dtype bfloat16 to SMOLDOCLING and SMOLVLM model spec
* Clip off pad_token

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
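A minimal usage sketch of what this change enables (illustrative, not part of the commit; it assumes the standard docling VlmPipeline wiring, default VlmPipelineOptions, and an example input file):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.settings import settings
from docling.datamodel.vlm_model_specs import SMOLDOCLING_VLLM
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Number of pages handed to the VLM in one batch
# (also exposed as --page-batch-size on the CLI by this change).
settings.perf.page_batch_size = 8

pipeline_options = VlmPipelineOptions()
pipeline_options.vlm_options = SMOLDOCLING_VLLM  # new vLLM-backed SmolDocling preset

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)
result = converter.convert("report.pdf")  # illustrative input path
print(result.document.export_to_markdown())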
@@ -60,10 +60,12 @@ from docling.datamodel.pipeline_options import (
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.datamodel.vlm_model_specs import (
|
||||
GOT2_TRANSFORMERS,
|
||||
GRANITE_VISION_OLLAMA,
|
||||
GRANITE_VISION_TRANSFORMERS,
|
||||
SMOLDOCLING_MLX,
|
||||
SMOLDOCLING_TRANSFORMERS,
|
||||
SMOLDOCLING_VLLM,
|
||||
VlmModelType,
|
||||
)
|
||||
from docling.document_converter import (
|
||||
@@ -477,6 +479,13 @@ def convert( # noqa: C901
|
||||
"--logo", callback=logo_callback, is_eager=True, help="Docling logo"
|
||||
),
|
||||
] = None,
|
||||
page_batch_size: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
"--page-batch-size",
|
||||
help=f"Number of pages processed in one batch. Default: {settings.perf.page_batch_size}",
|
||||
),
|
||||
] = settings.perf.page_batch_size,
|
||||
):
|
||||
log_format = "%(asctime)s\t%(levelname)s\t%(name)s: %(message)s"
|
||||
|
||||
@@ -491,6 +500,7 @@ def convert( # noqa: C901
|
||||
settings.debug.visualize_layout = debug_visualize_layout
|
||||
settings.debug.visualize_tables = debug_visualize_tables
|
||||
settings.debug.visualize_ocr = debug_visualize_ocr
|
||||
settings.perf.page_batch_size = page_batch_size
|
||||
|
||||
if from_formats is None:
|
||||
from_formats = list(InputFormat)
|
||||
@@ -631,6 +641,8 @@ def convert( # noqa: C901
|
||||
pipeline_options.vlm_options = GRANITE_VISION_TRANSFORMERS
|
||||
elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
|
||||
pipeline_options.vlm_options = GRANITE_VISION_OLLAMA
|
||||
elif vlm_model == VlmModelType.GOT_OCR_2:
|
||||
pipeline_options.vlm_options = GOT2_TRANSFORMERS
|
||||
elif vlm_model == VlmModelType.SMOLDOCLING:
|
||||
pipeline_options.vlm_options = SMOLDOCLING_TRANSFORMERS
|
||||
if sys.platform == "darwin":
|
||||
@@ -643,6 +655,8 @@ def convert( # noqa: C901
|
||||
"To run SmolDocling faster, please install mlx-vlm:\n"
|
||||
"pip install mlx-vlm"
|
||||
)
|
||||
elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
|
||||
pipeline_options.vlm_options = SMOLDOCLING_VLLM
|
||||
|
||||
pdf_format_option = PdfFormatOption(
|
||||
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import (
|
||||
|
||||
@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
|
||||
keep_empty_clusters: bool = (
|
||||
False # Whether to keep clusters that contain no text cells
|
||||
)
|
||||
skip_cell_assignment: bool = (
|
||||
False # Skip cell-to-cluster assignment for VLM-only processing
|
||||
)
|
||||
model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
|
||||
|
||||
|
||||
|
||||
@@ -26,11 +26,14 @@ class ResponseFormat(str, Enum):
|
||||
DOCTAGS = "doctags"
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
OTSL = "otsl"
|
||||
PLAINTEXT = "plaintext"
|
||||
|
||||
|
||||
class InferenceFramework(str, Enum):
|
||||
MLX = "mlx"
|
||||
TRANSFORMERS = "transformers"
|
||||
VLLM = "vllm"
|
||||
|
||||
|
||||
class TransformersModelType(str, Enum):
|
||||
@@ -43,6 +46,7 @@ class TransformersModelType(str, Enum):
|
||||
class TransformersPromptStyle(str, Enum):
|
||||
CHAT = "chat"
|
||||
RAW = "raw"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class InlineVlmOptions(BaseVlmOptions):
|
||||
@@ -68,6 +72,7 @@ class InlineVlmOptions(BaseVlmOptions):
|
||||
|
||||
stop_strings: List[str] = []
|
||||
extra_generation_config: Dict[str, Any] = {}
|
||||
extra_processor_kwargs: Dict[str, Any] = {}
|
||||
|
||||
use_kv_cache: bool = True
|
||||
max_new_tokens: int = 4096
|
||||
|
||||
@@ -12,6 +12,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
||||
InlineVlmOptions,
|
||||
ResponseFormat,
|
||||
TransformersModelType,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@@ -26,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
|
||||
supported_devices=[AcceleratorDevice.MPS],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["</doctag>", "<end_of_utterance>"],
|
||||
)
|
||||
|
||||
SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
|
||||
@@ -33,16 +35,74 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
|
||||
prompt="Convert this page to docling.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
inference_framework=InferenceFramework.TRANSFORMERS,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.CUDA,
|
||||
],
|
||||
torch_dtype="bfloat16",
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["</doctag>", "<end_of_utterance>"],
|
||||
)
|
||||
|
||||
SMOLDOCLING_VLLM = InlineVlmOptions(
|
||||
repo_id="ds4sd/SmolDocling-256M-preview",
|
||||
prompt="Convert this page to docling.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
inference_framework=InferenceFramework.VLLM,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CUDA,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["</doctag>", "<end_of_utterance>"],
|
||||
)
|
||||
|
||||
# SmolVLM-256M-Instruct
|
||||
SMOLVLM256_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
|
||||
prompt="Transcribe this image to plain text.",
|
||||
response_format=ResponseFormat.PLAINTEXT,
|
||||
inference_framework=InferenceFramework.TRANSFORMERS,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.CUDA,
|
||||
# AcceleratorDevice.MPS,
|
||||
],
|
||||
torch_dtype="bfloat16",
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# SmolVLM2-2.2b-Instruct
|
||||
SMOLVLM256_MLX = InlineVlmOptions(
|
||||
repo_id="moot20/SmolVLM-256M-Instruct-MLX",
|
||||
prompt="Extract the text.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
inference_framework=InferenceFramework.MLX,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
SMOLVLM256_VLLM = InlineVlmOptions(
|
||||
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
|
||||
prompt="Transcribe this image to plain text.",
|
||||
response_format=ResponseFormat.PLAINTEXT,
|
||||
inference_framework=InferenceFramework.VLLM,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CUDA,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
# GraniteVision
|
||||
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="ibm-granite/granite-vision-3.2-2b",
|
||||
@@ -59,6 +119,18 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
GRANITE_VISION_VLLM = InlineVlmOptions(
|
||||
repo_id="ibm-granite/granite-vision-3.2-2b",
|
||||
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
|
||||
response_format=ResponseFormat.MARKDOWN,
|
||||
inference_framework=InferenceFramework.VLLM,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CUDA,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
GRANITE_VISION_OLLAMA = ApiVlmOptions(
|
||||
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
|
||||
params={"model": "granite3.2-vision:2b"},
|
||||
@@ -116,6 +188,26 @@ QWEN25_VL_3B_MLX = InlineVlmOptions(
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# GoT 2.0
|
||||
GOT2_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="stepfun-ai/GOT-OCR-2.0-hf",
|
||||
prompt="",
|
||||
response_format=ResponseFormat.MARKDOWN,
|
||||
inference_framework=InferenceFramework.TRANSFORMERS,
|
||||
transformers_prompt_style=TransformersPromptStyle.NONE,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.CUDA,
|
||||
# AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["<|im_end|>"],
|
||||
extra_processor_kwargs={"format": True},
|
||||
)
|
||||
|
||||
|
||||
# Gemma-3
|
||||
GEMMA3_12B_MLX = InlineVlmOptions(
|
||||
repo_id="mlx-community/gemma-3-12b-it-bf16",
|
||||
@@ -137,8 +229,29 @@ GEMMA3_27B_MLX = InlineVlmOptions(
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# Dolphin
|
||||
|
||||
DOLPHIN_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="ByteDance/Dolphin",
|
||||
prompt="<s>Read text in the image. <Answer/>",
|
||||
response_format=ResponseFormat.MARKDOWN,
|
||||
inference_framework=InferenceFramework.TRANSFORMERS,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
transformers_prompt_style=TransformersPromptStyle.RAW,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CUDA,
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
class VlmModelType(str, Enum):
|
||||
SMOLDOCLING = "smoldocling"
|
||||
SMOLDOCLING_VLLM = "smoldocling_vllm"
|
||||
GRANITE_VISION = "granite_vision"
|
||||
GRANITE_VISION_VLLM = "granite_vision_vllm"
|
||||
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
|
||||
GOT_OCR_2 = "got_ocr_2"
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import Generic, Optional, Protocol, Type
|
||||
from typing import Any, Generic, Optional, Protocol, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
|
||||
from PIL.Image import Image
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
||||
from docling.datamodel.base_models import (
|
||||
ItemAndImageEnrichmentElement,
|
||||
Page,
|
||||
VlmPrediction,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import BaseOptions
|
||||
from docling.datamodel.pipeline_options_vlm_model import (
|
||||
InlineVlmOptions,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
|
||||
|
||||
@@ -26,6 +37,88 @@ class BasePageModel(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class BaseVlmModel(ABC):
|
||||
"""Base class for Vision-Language Models that adds image processing capability."""
|
||||
|
||||
@abstractmethod
|
||||
def process_images(
|
||||
self,
|
||||
image_batch: Iterable[Union[Image, np.ndarray]],
|
||||
prompt: Union[str, list[str]],
|
||||
) -> Iterable[VlmPrediction]:
|
||||
"""Process raw images without page metadata.
|
||||
|
||||
Args:
|
||||
image_batch: Iterable of PIL Images or numpy arrays
|
||||
prompt: Either:
|
||||
- str: Single prompt used for all images
|
||||
- list[str]: List of prompts (one per image, must match image count)
|
||||
|
||||
Raises:
|
||||
ValueError: If prompt list length doesn't match image count.
|
||||
"""
|
||||
|
||||
|
||||
class BaseVlmPageModel(BasePageModel, BaseVlmModel):
|
||||
"""Base implementation for VLM models that inherit from BasePageModel.
|
||||
|
||||
Provides a default __call__ implementation that extracts images from pages,
|
||||
processes them using process_images, and attaches results back to pages.
|
||||
"""
|
||||
|
||||
# Type annotations for attributes that subclasses must initialize
|
||||
vlm_options: InlineVlmOptions
|
||||
processor: Any
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
"""Extract images from pages, process them, and attach results back."""
|
||||
|
||||
def formulate_prompt(self, user_prompt: str) -> str:
|
||||
"""Formulate a prompt for the VLM."""
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
||||
return user_prompt
|
||||
|
||||
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
|
||||
_log.debug("Using specialized prompt for Phi-4")
|
||||
# Note: This might need adjustment for VLLM vs transformers
|
||||
user_prompt_prefix = "<|user|>"
|
||||
assistant_prompt = "<|assistant|>"
|
||||
prompt_suffix = "<|end|>"
|
||||
|
||||
prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
|
||||
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
|
||||
|
||||
return prompt
|
||||
|
||||
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This is a page from a document.",
|
||||
},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": user_prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt = self.processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
return prompt
|
||||
|
||||
raise RuntimeError(
|
||||
f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
|
||||
)
|
||||
|
||||
|
||||
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
|
||||
|
||||
class PagePreprocessingOptions(BaseModel):
|
||||
images_scale: Optional[float]
|
||||
skip_cell_extraction: bool = (
|
||||
False # Skip text cell extraction for VLM-only processing
|
||||
)
|
||||
|
||||
|
||||
class PagePreprocessingModel(BasePageModel):
|
||||
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
|
||||
else:
|
||||
with TimeRecorder(conv_res, "page_parse"):
|
||||
page = self._populate_page_images(page)
|
||||
page = self._parse_page_cells(conv_res, page)
|
||||
if not self.options.skip_cell_extraction:
|
||||
page = self._parse_page_cells(conv_res, page)
|
||||
yield page
|
||||
|
||||
# Generate the page image and store it in the page object
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
@@ -63,7 +64,7 @@ class PictureDescriptionVlmModel(
|
||||
# Initialize processor and model
|
||||
with _model_init_lock:
|
||||
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
||||
self.model = AutoModelForVision2Seq.from_pretrained(
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||
artifacts_path,
|
||||
device_map=self.device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
@@ -71,9 +72,10 @@ class PictureDescriptionVlmModel(
|
||||
"flash_attention_2"
|
||||
if self.device.startswith("cuda")
|
||||
and accelerator_options.cuda_use_flash_attention2
|
||||
else "eager"
|
||||
else "sdpa"
|
||||
),
|
||||
)
|
||||
self.model = torch.compile(self.model) # type: ignore
|
||||
|
||||
self.provenance = f"{self.options.repo_id}"
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@ import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from transformers import StoppingCriteriaList, StopStringCriteria
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
@@ -15,7 +19,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
||||
TransformersModelType,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.base_model import BaseVlmPageModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
@@ -25,7 +29,7 @@ from docling.utils.profiling import TimeRecorder
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@@ -103,6 +107,8 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
|
||||
artifacts_path,
|
||||
trust_remote_code=vlm_options.trust_remote_code,
|
||||
)
|
||||
self.processor.tokenizer.padding_side = "left"
|
||||
|
||||
self.vlm_model = model_cls.from_pretrained(
|
||||
artifacts_path,
|
||||
device_map=self.device,
|
||||
@@ -111,10 +117,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
|
||||
"flash_attention_2"
|
||||
if self.device.startswith("cuda")
|
||||
and accelerator_options.cuda_use_flash_attention2
|
||||
else "eager"
|
||||
else "sdpa"
|
||||
),
|
||||
trust_remote_code=vlm_options.trust_remote_code,
|
||||
)
|
||||
self.vlm_model = torch.compile(self.vlm_model) # type: ignore
|
||||
|
||||
# Load generation config
|
||||
self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
|
||||
@@ -122,93 +129,186 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
page_list = list(page_batch)
|
||||
if not page_list:
|
||||
return
|
||||
|
||||
valid_pages = []
|
||||
invalid_pages = []
|
||||
|
||||
for page in page_list:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
invalid_pages.append(page)
|
||||
else:
|
||||
with TimeRecorder(conv_res, "vlm"):
|
||||
assert page.size is not None
|
||||
valid_pages.append(page)
|
||||
|
||||
# Process valid pages in batch
|
||||
if valid_pages:
|
||||
with TimeRecorder(conv_res, "vlm"):
|
||||
# Prepare images and prompts for batch processing
|
||||
images = []
|
||||
user_prompts = []
|
||||
pages_with_images = []
|
||||
|
||||
for page in valid_pages:
|
||||
assert page.size is not None
|
||||
hi_res_image = page.get_image(
|
||||
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
||||
)
|
||||
|
||||
# Define prompt structure
|
||||
user_prompt = self.vlm_options.build_prompt(page.parsed_page)
|
||||
prompt = self.formulate_prompt(user_prompt)
|
||||
# Only process pages with valid images
|
||||
if hi_res_image is not None:
|
||||
images.append(hi_res_image)
|
||||
|
||||
inputs = self.processor(
|
||||
text=prompt, images=[hi_res_image], return_tensors="pt"
|
||||
).to(self.device)
|
||||
# Define prompt structure
|
||||
user_prompt = self.vlm_options.build_prompt(page.parsed_page)
|
||||
|
||||
start_time = time.time()
|
||||
# Call model to generate:
|
||||
generated_ids = self.vlm_model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
use_cache=self.use_cache,
|
||||
temperature=self.temperature,
|
||||
generation_config=self.generation_config,
|
||||
**self.vlm_options.extra_generation_config,
|
||||
)
|
||||
user_prompts.append(user_prompt)
|
||||
pages_with_images.append(page)
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
generated_texts = self.processor.batch_decode(
|
||||
generated_ids[:, inputs["input_ids"].shape[1] :],
|
||||
skip_special_tokens=False,
|
||||
)[0]
|
||||
# Use process_images for the actual inference
|
||||
if images: # Only if we have valid images
|
||||
predictions = list(self.process_images(images, user_prompts))
|
||||
|
||||
num_tokens = len(generated_ids[0])
|
||||
_log.debug(
|
||||
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
|
||||
)
|
||||
generated_texts = self.vlm_options.decode_response(generated_texts)
|
||||
page.predictions.vlm_response = VlmPrediction(
|
||||
text=generated_texts,
|
||||
generation_time=generation_time,
|
||||
)
|
||||
# Attach results to pages
|
||||
for page, prediction in zip(pages_with_images, predictions):
|
||||
page.predictions.vlm_response = prediction
|
||||
|
||||
yield page
|
||||
# Yield all pages (valid and invalid)
|
||||
for page in invalid_pages:
|
||||
yield page
|
||||
for page in valid_pages:
|
||||
yield page
|
||||
|
||||
def formulate_prompt(self, user_prompt: str) -> str:
|
||||
"""Formulate a prompt for the VLM."""
|
||||
def process_images(
|
||||
self,
|
||||
image_batch: Iterable[Union[Image, np.ndarray]],
|
||||
prompt: Union[str, list[str]],
|
||||
) -> Iterable[VlmPrediction]:
|
||||
"""
|
||||
Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
|
||||
- Lets the processor handle all padding & batching for text+images.
|
||||
- Trims generated sequences per row using attention_mask (no pad-id fallbacks).
|
||||
- Keeps your formulate_prompt() exactly as-is.
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
|
||||
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
||||
return user_prompt
|
||||
# -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
|
||||
pil_images: list[Image] = []
|
||||
for img in image_batch:
|
||||
if isinstance(img, np.ndarray):
|
||||
if img.ndim == 3 and img.shape[2] in (3, 4):
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8))
|
||||
elif img.ndim == 2:
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
|
||||
else:
|
||||
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
|
||||
else:
|
||||
pil_img = img
|
||||
if pil_img.mode != "RGB":
|
||||
pil_img = pil_img.convert("RGB")
|
||||
pil_images.append(pil_img)
|
||||
|
||||
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
|
||||
_log.debug("Using specialized prompt for Phi-4")
|
||||
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
|
||||
if not pil_images:
|
||||
return
|
||||
|
||||
user_prompt = "<|user|>"
|
||||
assistant_prompt = "<|assistant|>"
|
||||
prompt_suffix = "<|end|>"
|
||||
# -- Normalize prompts (1 per image)
|
||||
if isinstance(prompt, str):
|
||||
user_prompts = [prompt] * len(pil_images)
|
||||
else:
|
||||
if len(prompt) != len(pil_images):
|
||||
raise ValueError(
|
||||
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
|
||||
)
|
||||
user_prompts = prompt
|
||||
|
||||
prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
|
||||
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
|
||||
|
||||
return prompt
|
||||
|
||||
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This is a page from a document.",
|
||||
},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": user_prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt = self.processor.apply_chat_template(
|
||||
messages, add_generation_prompt=False
|
||||
# Use your prompt formatter verbatim
|
||||
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
|
||||
inputs = self.processor(
|
||||
pil_images,
|
||||
return_tensors="pt",
|
||||
padding=True, # pad across batch for both text and vision
|
||||
**self.vlm_options.extra_processor_kwargs,
|
||||
)
|
||||
return prompt
|
||||
else:
|
||||
prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
|
||||
|
||||
raise RuntimeError(
|
||||
f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
|
||||
# -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
|
||||
inputs = self.processor(
|
||||
text=prompts,
|
||||
images=pil_images,
|
||||
return_tensors="pt",
|
||||
padding=True, # pad across batch for both text and vision
|
||||
**self.vlm_options.extra_processor_kwargs,
|
||||
)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# -- Optional stopping criteria
|
||||
stopping_criteria = None
|
||||
if self.vlm_options.stop_strings:
|
||||
stopping_criteria = StoppingCriteriaList(
|
||||
[
|
||||
StopStringCriteria(
|
||||
stop_strings=self.vlm_options.stop_strings,
|
||||
tokenizer=self.processor.tokenizer,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# -- Generate (Image-Text-to-Text class expects these inputs from processor)
|
||||
gen_kwargs = {
|
||||
**inputs,
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"use_cache": self.use_cache,
|
||||
"generation_config": self.generation_config,
|
||||
**self.vlm_options.extra_generation_config,
|
||||
}
|
||||
if self.temperature > 0:
|
||||
gen_kwargs["do_sample"] = True
|
||||
gen_kwargs["temperature"] = self.temperature
|
||||
else:
|
||||
gen_kwargs["do_sample"] = False
|
||||
|
||||
if stopping_criteria is not None:
|
||||
gen_kwargs["stopping_criteria"] = stopping_criteria
|
||||
|
||||
start_time = time.time()
|
||||
with torch.inference_mode():
|
||||
generated_ids = self.vlm_model.generate(**gen_kwargs)
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
input_len = inputs["input_ids"].shape[1] # common right-aligned prompt length
|
||||
trimmed_sequences = generated_ids[:, input_len:] # only newly generated tokens
|
||||
|
||||
# -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
|
||||
decode_fn = getattr(self.processor, "batch_decode", None)
|
||||
if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
|
||||
decode_fn = self.processor.tokenizer.batch_decode
|
||||
if decode_fn is None:
|
||||
raise RuntimeError(
|
||||
"Neither processor.batch_decode nor tokenizer.batch_decode is available."
|
||||
)
|
||||
|
||||
decoded_texts: list[str] = decode_fn(
|
||||
trimmed_sequences, skip_special_tokens=False
|
||||
)
|
||||
|
||||
# -- Clip off pad tokens from decoded texts
|
||||
pad_token = self.processor.tokenizer.pad_token
|
||||
if pad_token:
|
||||
decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
|
||||
|
||||
# -- Optional logging
|
||||
if generated_ids.shape[0] > 0:
|
||||
_log.debug(
|
||||
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
|
||||
f"for batch size {generated_ids.shape[0]}."
|
||||
)
|
||||
|
||||
for text in decoded_texts:
|
||||
# Apply decode_response to the output text
|
||||
decoded_text = self.vlm_options.decode_response(text)
|
||||
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
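A design note on the batching above, with a small illustrative snippet (made-up token IDs, not from the diff): the tokenizer is set to left padding in __init__ (processor.tokenizer.padding_side = "left"), so every prompt in the batch is right-aligned; the model then continues from real prompt tokens rather than from padding, and the single slice generated_ids[:, input_len:] cleanly isolates the newly generated tokens for every row.

import torch

# Illustrative only: two prompts of different length, left-padded into one batch.
PAD = 0
generated_ids = torch.tensor(
    [
        [PAD, PAD, 11, 12, 13, 101, 102],  # short prompt, left-padded, then 2 generated tokens
        [21, 22, 23, 24, 25, 201, 202],    # longer prompt, no padding needed
    ]
)
input_len = 5  # == inputs["input_ids"].shape[1] after padding
print(generated_ids[:, input_len:])  # tensor([[101, 102], [201, 202]])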
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
|
||||
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.base_model import BaseVlmPageModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
# Global lock for MLX model calls - MLX models are not thread-safe
|
||||
# All MLX models share this lock to prevent concurrent MLX operations
|
||||
_MLX_GLOBAL_LOCK = threading.Lock()
|
||||
|
||||
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
|
||||
class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@@ -63,87 +71,190 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
page_list = list(page_batch)
|
||||
if not page_list:
|
||||
return
|
||||
|
||||
valid_pages = []
|
||||
invalid_pages = []
|
||||
|
||||
for page in page_list:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
invalid_pages.append(page)
|
||||
else:
|
||||
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
|
||||
assert page.size is not None
|
||||
valid_pages.append(page)
|
||||
|
||||
# Process valid pages in batch
|
||||
if valid_pages:
|
||||
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
|
||||
# Prepare images and prompts for batch processing
|
||||
images = []
|
||||
user_prompts = []
|
||||
pages_with_images = []
|
||||
|
||||
for page in valid_pages:
|
||||
assert page.size is not None
|
||||
hi_res_image = page.get_image(
|
||||
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
||||
)
|
||||
|
||||
# Only process pages with valid images
|
||||
if hi_res_image is not None:
|
||||
im_width, im_height = hi_res_image.size
|
||||
images.append(hi_res_image)
|
||||
|
||||
# populate page_tags with predicted doc tags
|
||||
page_tags = ""
|
||||
|
||||
if hi_res_image:
|
||||
if hi_res_image.mode != "RGB":
|
||||
hi_res_image = hi_res_image.convert("RGB")
|
||||
|
||||
user_prompt = self.vlm_options.build_prompt(page.parsed_page)
|
||||
prompt = self.apply_chat_template(
|
||||
self.processor, self.config, user_prompt, num_images=1
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
_log.debug("start generating ...")
|
||||
|
||||
# Call model to generate:
|
||||
tokens: list[VlmPredictionToken] = []
|
||||
|
||||
output = ""
|
||||
for token in self.stream_generate(
|
||||
self.vlm_model,
|
||||
self.processor,
|
||||
prompt,
|
||||
[hi_res_image],
|
||||
max_tokens=self.max_tokens,
|
||||
verbose=False,
|
||||
temp=self.temperature,
|
||||
):
|
||||
if len(token.logprobs.shape) == 1:
|
||||
tokens.append(
|
||||
VlmPredictionToken(
|
||||
text=token.text,
|
||||
token=token.token,
|
||||
logprob=token.logprobs[token.token],
|
||||
)
|
||||
)
|
||||
elif (
|
||||
len(token.logprobs.shape) == 2
|
||||
and token.logprobs.shape[0] == 1
|
||||
):
|
||||
tokens.append(
|
||||
VlmPredictionToken(
|
||||
text=token.text,
|
||||
token=token.token,
|
||||
logprob=token.logprobs[0, token.token],
|
||||
)
|
||||
)
|
||||
# Define prompt structure
|
||||
if callable(self.vlm_options.prompt):
|
||||
user_prompt = self.vlm_options.prompt(page.parsed_page)
|
||||
else:
|
||||
_log.warning(
|
||||
f"incompatible shape for logprobs: {token.logprobs.shape}"
|
||||
)
|
||||
user_prompt = self.vlm_options.prompt
|
||||
|
||||
output += token.text
|
||||
if "</doctag>" in token.text:
|
||||
user_prompts.append(user_prompt)
|
||||
pages_with_images.append(page)
|
||||
|
||||
# Use process_images for the actual inference
|
||||
if images: # Only if we have valid images
|
||||
predictions = list(self.process_images(images, user_prompts))
|
||||
|
||||
# Attach results to pages
|
||||
for page, prediction in zip(pages_with_images, predictions):
|
||||
page.predictions.vlm_response = prediction
|
||||
|
||||
# Yield all pages (valid and invalid)
|
||||
for page in invalid_pages:
|
||||
yield page
|
||||
for page in valid_pages:
|
||||
yield page
|
||||
|
||||
def process_images(
|
||||
self,
|
||||
image_batch: Iterable[Union[Image, np.ndarray]],
|
||||
prompt: Union[str, list[str]],
|
||||
) -> Iterable[VlmPrediction]:
|
||||
"""Process raw images without page metadata.
|
||||
|
||||
Args:
|
||||
image_batch: Iterable of PIL Images or numpy arrays
|
||||
prompt: Either:
|
||||
- str: Single prompt used for all images
|
||||
- list[str]: List of prompts (one per image, must match image count)
|
||||
|
||||
Raises:
|
||||
ValueError: If prompt list length doesn't match image count.
|
||||
"""
|
||||
# Convert image batch to list for length validation
|
||||
image_list = list(image_batch)
|
||||
|
||||
if len(image_list) == 0:
|
||||
return
|
||||
|
||||
# Handle prompt parameter
|
||||
if isinstance(prompt, str):
|
||||
# Single prompt for all images
|
||||
user_prompts = [prompt] * len(image_list)
|
||||
elif isinstance(prompt, list):
|
||||
# List of prompts (one per image)
|
||||
if len(prompt) != len(image_list):
|
||||
raise ValueError(
|
||||
f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
|
||||
)
|
||||
user_prompts = prompt
|
||||
else:
|
||||
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
|
||||
|
||||
# MLX models are not thread-safe - use global lock to serialize access
|
||||
with _MLX_GLOBAL_LOCK:
|
||||
_log.debug("MLX model: Acquired global lock for thread safety")
|
||||
for image, user_prompt in zip(image_list, user_prompts):
|
||||
# Convert numpy array to PIL Image if needed
|
||||
if isinstance(image, np.ndarray):
|
||||
if image.ndim == 3 and image.shape[2] in [3, 4]:
|
||||
# RGB or RGBA array
|
||||
from PIL import Image as PILImage
|
||||
|
||||
image = PILImage.fromarray(image.astype(np.uint8))
|
||||
elif image.ndim == 2:
|
||||
# Grayscale array
|
||||
from PIL import Image as PILImage
|
||||
|
||||
image = PILImage.fromarray(image.astype(np.uint8), mode="L")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported numpy array shape: {image.shape}"
|
||||
)
|
||||
|
||||
# Ensure image is in RGB mode (handles RGBA, L, etc.)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# Use the MLX chat template approach like in the __call__ method
|
||||
formatted_prompt = self.apply_chat_template(
|
||||
self.processor, self.config, user_prompt, num_images=1
|
||||
)
|
||||
|
||||
# Stream generate with stop strings support
|
||||
start_time = time.time()
|
||||
_log.debug("start generating ...")
|
||||
|
||||
tokens: list[VlmPredictionToken] = []
|
||||
output = ""
|
||||
|
||||
# Use stream_generate for proper stop string handling
|
||||
for token in self.stream_generate(
|
||||
self.vlm_model,
|
||||
self.processor,
|
||||
formatted_prompt,
|
||||
[image], # MLX stream_generate expects list of images
|
||||
max_tokens=self.max_tokens,
|
||||
verbose=False,
|
||||
temp=self.temperature,
|
||||
):
|
||||
# Collect token information
|
||||
if len(token.logprobs.shape) == 1:
|
||||
tokens.append(
|
||||
VlmPredictionToken(
|
||||
text=token.text,
|
||||
token=token.token,
|
||||
logprob=token.logprobs[token.token],
|
||||
)
|
||||
)
|
||||
elif (
|
||||
len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
|
||||
):
|
||||
tokens.append(
|
||||
VlmPredictionToken(
|
||||
text=token.text,
|
||||
token=token.token,
|
||||
logprob=token.logprobs[0, token.token],
|
||||
)
|
||||
)
|
||||
else:
|
||||
_log.warning(
|
||||
f"incompatible shape for logprobs: {token.logprobs.shape}"
|
||||
)
|
||||
|
||||
output += token.text
|
||||
|
||||
# Check for any configured stop strings
|
||||
if self.vlm_options.stop_strings:
|
||||
if any(
|
||||
stop_str in output
|
||||
for stop_str in self.vlm_options.stop_strings
|
||||
):
|
||||
_log.debug("Stopping generation due to stop string match")
|
||||
break
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
page_tags = output
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
_log.debug(
|
||||
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
|
||||
)
|
||||
page_tags = self.vlm_options.decode_response(page_tags)
|
||||
page.predictions.vlm_response = VlmPrediction(
|
||||
text=page_tags,
|
||||
generation_time=generation_time,
|
||||
generated_tokens=tokens,
|
||||
)
|
||||
_log.debug(
|
||||
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
|
||||
)
|
||||
|
||||
yield page
|
||||
# Apply decode_response to the output before yielding
|
||||
decoded_output = self.vlm_options.decode_response(output)
|
||||
yield VlmPrediction(
|
||||
text=decoded_output,
|
||||
generation_time=generation_time,
|
||||
generated_tokens=tokens,
|
||||
)
|
||||
_log.debug("MLX model: Released global lock")
|
||||
|
||||
docling/models/vlm_models_inline/vllm_model.py (new file, 235 lines)
@@ -0,0 +1,235 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
)
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options_vlm_model import (
|
||||
InlineVlmOptions,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
from docling.models.base_model import BaseVlmPageModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Path],
|
||||
accelerator_options: AcceleratorOptions,
|
||||
vlm_options: InlineVlmOptions,
|
||||
):
|
||||
self.enabled = enabled
|
||||
|
||||
self.vlm_options = vlm_options
|
||||
|
||||
if self.enabled:
|
||||
from transformers import AutoProcessor
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
self.device = decide_device(
|
||||
accelerator_options.device,
|
||||
supported_devices=vlm_options.supported_devices,
|
||||
)
|
||||
_log.debug(f"Available device for VLM: {self.device}")
|
||||
|
||||
self.max_new_tokens = vlm_options.max_new_tokens
|
||||
self.temperature = vlm_options.temperature
|
||||
|
||||
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models(self.vlm_options.repo_id)
|
||||
elif (artifacts_path / repo_cache_folder).exists():
|
||||
artifacts_path = artifacts_path / repo_cache_folder
|
||||
|
||||
# Initialize VLLM LLM
|
||||
llm_kwargs: Dict[str, Any] = {
|
||||
"model": str(artifacts_path),
|
||||
"limit_mm_per_prompt": {"image": 1},
|
||||
"trust_remote_code": vlm_options.trust_remote_code,
|
||||
"model_impl": "transformers",
|
||||
"gpu_memory_utilization": 0.3, # hardcoded for now, leaves room for ~3 different models.
|
||||
}
|
||||
|
||||
# Add device-specific configurations
|
||||
|
||||
if self.device == "cpu":
|
||||
llm_kwargs["device"] = "cpu"
|
||||
|
||||
# Add quantization if specified
|
||||
if vlm_options.quantized:
|
||||
if vlm_options.load_in_8bit:
|
||||
llm_kwargs["quantization"] = "bitsandbytes"
|
||||
|
||||
self.llm = LLM(**llm_kwargs)
|
||||
|
||||
# Initialize processor for prompt formatting
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
artifacts_path,
|
||||
trust_remote_code=vlm_options.trust_remote_code,
|
||||
)
|
||||
|
||||
# Set up sampling parameters
|
||||
self.sampling_params = SamplingParams(
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_new_tokens,
|
||||
stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
|
||||
**vlm_options.extra_generation_config,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
page_list = list(page_batch)
|
||||
if not page_list:
|
||||
return
|
||||
|
||||
valid_pages = []
|
||||
invalid_pages = []
|
||||
|
||||
for page in page_list:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
invalid_pages.append(page)
|
||||
else:
|
||||
valid_pages.append(page)
|
||||
|
||||
# Process valid pages in batch
|
||||
if valid_pages:
|
||||
with TimeRecorder(conv_res, "vlm"):
|
||||
# Prepare images and prompts for batch processing
|
||||
images = []
|
||||
user_prompts = []
|
||||
pages_with_images = []
|
||||
|
||||
for page in valid_pages:
|
||||
assert page.size is not None
|
||||
hi_res_image = page.get_image(
|
||||
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
||||
)
|
||||
|
||||
# Only process pages with valid images
|
||||
if hi_res_image is not None:
|
||||
images.append(hi_res_image)
|
||||
|
||||
# Define prompt structure
|
||||
if callable(self.vlm_options.prompt):
|
||||
user_prompt = self.vlm_options.prompt(page.parsed_page)
|
||||
else:
|
||||
user_prompt = self.vlm_options.prompt
|
||||
|
||||
user_prompts.append(user_prompt)
|
||||
pages_with_images.append(page)
|
||||
|
||||
# Use process_images for the actual inference
|
||||
if images: # Only if we have valid images
|
||||
predictions = list(self.process_images(images, user_prompts))
|
||||
|
||||
# Attach results to pages
|
||||
for page, prediction in zip(pages_with_images, predictions):
|
||||
page.predictions.vlm_response = prediction
|
||||
|
||||
# Yield all pages (valid and invalid)
|
||||
for page in invalid_pages:
|
||||
yield page
|
||||
for page in valid_pages:
|
||||
yield page
|
||||
|
||||
def process_images(
|
||||
self,
|
||||
image_batch: Iterable[Union[Image, np.ndarray]],
|
||||
prompt: Union[str, list[str]],
|
||||
) -> Iterable[VlmPrediction]:
|
||||
"""Process raw images without page metadata in a single batched inference call.
|
||||
|
||||
Args:
|
||||
image_batch: Iterable of PIL Images or numpy arrays
|
||||
prompt: Either:
|
||||
- str: Single prompt used for all images
|
||||
- list[str]: List of prompts (one per image, must match image count)
|
||||
|
||||
Raises:
|
||||
ValueError: If prompt list length doesn't match image count.
|
||||
"""
|
||||
pil_images: list[Image] = []
|
||||
|
||||
for img in image_batch:
|
||||
# Convert numpy array to PIL Image if needed
|
||||
if isinstance(img, np.ndarray):
|
||||
if img.ndim == 3 and img.shape[2] in [3, 4]:
|
||||
from PIL import Image as PILImage
|
||||
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8))
|
||||
elif img.ndim == 2:
|
||||
from PIL import Image as PILImage
|
||||
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
|
||||
else:
|
||||
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
|
||||
else:
|
||||
pil_img = img
|
||||
|
||||
# Ensure image is in RGB mode (handles RGBA, L, etc.)
|
||||
if pil_img.mode != "RGB":
|
||||
pil_img = pil_img.convert("RGB")
|
||||
|
||||
pil_images.append(pil_img)
|
||||
|
||||
if len(pil_images) == 0:
|
||||
return
|
||||
|
||||
# Handle prompt parameter
|
||||
if isinstance(prompt, str):
|
||||
# Single prompt for all images
|
||||
user_prompts = [prompt] * len(pil_images)
|
||||
elif isinstance(prompt, list):
|
||||
# List of prompts (one per image)
|
||||
if len(prompt) != len(pil_images):
|
||||
raise ValueError(
|
||||
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
|
||||
)
|
||||
user_prompts = prompt
|
||||
else:
|
||||
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
|
||||
|
||||
# Format prompts individually
|
||||
prompts: list[str] = [
|
||||
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
|
||||
]
|
||||
|
||||
# Prepare VLLM inputs
|
||||
llm_inputs = []
|
||||
for prompt, image in zip(prompts, pil_images):
|
||||
llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
|
||||
|
||||
start_time = time.time()
|
||||
outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params) # type: ignore
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
# Logging tokens count for the first sample as a representative metric
|
||||
if len(outputs) > 0:
|
||||
num_tokens = len(outputs[0].outputs[0].token_ids)
|
||||
_log.debug(
|
||||
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
|
||||
)
|
||||
|
||||
for output in outputs:
|
||||
# Apply decode_response to the output text
|
||||
decoded_text = self.vlm_options.decode_response(output.outputs[0].text)
|
||||
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
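As a standalone illustration of the new backend (not part of the diff; it assumes a Linux machine with CUDA, the vllm extra installed, and an illustrative image path):

from PIL import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.vlm_model_specs import SMOLDOCLING_VLLM
from docling.models.vlm_models_inline.vllm_model import VllmVlmModel

model = VllmVlmModel(
    enabled=True,
    artifacts_path=None,  # falls back to downloading ds4sd/SmolDocling-256M-preview
    accelerator_options=AcceleratorOptions(),  # resolves to CUDA, the only device this preset supports
    vlm_options=SMOLDOCLING_VLLM,
)

page_images = [Image.open("page_1.png")]  # illustrative input
for prediction in model.process_images(page_images, SMOLDOCLING_VLLM.prompt):
    print(prediction.generation_time, prediction.text[:120])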
|
||||
@@ -194,7 +194,7 @@ class ThreadedPipelineStage:
|
||||
return
|
||||
self._running = True
|
||||
self._thread = threading.Thread(
|
||||
target=self._run, name=f"Stage-{self.name}", daemon=False
|
||||
target=self._run, name=f"Stage-{self.name}", daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
|
||||
@@ -103,6 +103,17 @@ class VlmPipeline(PaginatedPipeline):
|
||||
vlm_options=vlm_options,
|
||||
),
|
||||
]
|
||||
elif vlm_options.inference_framework == InferenceFramework.VLLM:
|
||||
from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
|
||||
|
||||
self.build_pipe = [
|
||||
VllmVlmModel(
|
||||
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
vlm_options=vlm_options,
|
||||
),
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||
@@ -117,7 +128,9 @@ class VlmPipeline(PaginatedPipeline):
|
||||
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
|
||||
if page._backend is not None and page._backend.is_valid():
|
||||
page.size = page._backend.get_size()
|
||||
page.parsed_page = page._backend.get_segmented_page()
|
||||
|
||||
if self.force_backend_text:
|
||||
page.parsed_page = page._backend.get_segmented_page()
|
||||
|
||||
return page
|
||||
|
||||
|
||||
@@ -239,15 +239,18 @@ class LayoutPostprocessor:
|
||||
final_clusters = self._sort_clusters(
|
||||
self.regular_clusters + self.special_clusters, mode="id"
|
||||
)
|
||||
for cluster in final_clusters:
|
||||
cluster.cells = self._sort_cells(cluster.cells)
|
||||
# Also sort cells in children if any
|
||||
for child in cluster.children:
|
||||
child.cells = self._sort_cells(child.cells)
|
||||
|
||||
assert self.page.parsed_page is not None
|
||||
self.page.parsed_page.textline_cells = self.cells
|
||||
self.page.parsed_page.has_lines = len(self.cells) > 0
|
||||
# Conditionally process cells if not skipping cell assignment
|
||||
if not self.options.skip_cell_assignment:
|
||||
for cluster in final_clusters:
|
||||
cluster.cells = self._sort_cells(cluster.cells)
|
||||
# Also sort cells in children if any
|
||||
for child in cluster.children:
|
||||
child.cells = self._sort_cells(child.cells)
|
||||
|
||||
assert self.page.parsed_page is not None
|
||||
self.page.parsed_page.textline_cells = self.cells
|
||||
self.page.parsed_page.has_lines = len(self.cells) > 0
|
||||
|
||||
return final_clusters, self.cells
|
||||
|
||||
@@ -264,36 +267,38 @@ class LayoutPostprocessor:
|
||||
if cluster.label in self.LABEL_REMAPPING:
|
||||
cluster.label = self.LABEL_REMAPPING[cluster.label]
|
||||
|
||||
# Initial cell assignment
|
||||
clusters = self._assign_cells_to_clusters(clusters)
|
||||
# Conditionally assign cells to clusters
|
||||
if not self.options.skip_cell_assignment:
|
||||
# Initial cell assignment
|
||||
clusters = self._assign_cells_to_clusters(clusters)
|
||||
|
||||
# Remove clusters with no cells (if keep_empty_clusters is False),
|
||||
# but always keep clusters with label DocItemLabel.FORMULA
|
||||
if not self.options.keep_empty_clusters:
|
||||
clusters = [
|
||||
cluster
|
||||
for cluster in clusters
|
||||
if cluster.cells or cluster.label == DocItemLabel.FORMULA
|
||||
]
|
||||
# Remove clusters with no cells (if keep_empty_clusters is False),
|
||||
# but always keep clusters with label DocItemLabel.FORMULA
|
||||
if not self.options.keep_empty_clusters:
|
||||
clusters = [
|
||||
cluster
|
||||
for cluster in clusters
|
||||
if cluster.cells or cluster.label == DocItemLabel.FORMULA
|
||||
]
|
||||
|
||||
# Handle orphaned cells
|
||||
unassigned = self._find_unassigned_cells(clusters)
|
||||
if unassigned and self.options.create_orphan_clusters:
|
||||
next_id = max((c.id for c in self.all_clusters), default=0) + 1
|
||||
orphan_clusters = []
|
||||
for i, cell in enumerate(unassigned):
|
||||
conf = cell.confidence
|
||||
# Handle orphaned cells
|
||||
unassigned = self._find_unassigned_cells(clusters)
|
||||
if unassigned and self.options.create_orphan_clusters:
|
||||
next_id = max((c.id for c in self.all_clusters), default=0) + 1
|
||||
orphan_clusters = []
|
||||
for i, cell in enumerate(unassigned):
|
||||
conf = cell.confidence
|
||||
|
||||
orphan_clusters.append(
|
||||
Cluster(
|
||||
id=next_id + i,
|
||||
label=DocItemLabel.TEXT,
|
||||
bbox=cell.to_bounding_box(),
|
||||
confidence=conf,
|
||||
cells=[cell],
|
||||
orphan_clusters.append(
|
||||
Cluster(
|
||||
id=next_id + i,
|
||||
label=DocItemLabel.TEXT,
|
||||
bbox=cell.to_bounding_box(),
|
||||
confidence=conf,
|
||||
cells=[cell],
|
||||
)
|
||||
)
|
||||
)
|
||||
clusters.extend(orphan_clusters)
|
||||
clusters.extend(orphan_clusters)
|
||||
|
||||
# Iterative refinement
|
||||
prev_count = len(clusters) + 1
|
||||
@@ -350,12 +355,15 @@ class LayoutPostprocessor:
|
||||
b=max(c.bbox.b for c in contained),
|
||||
)
|
||||
|
||||
# Collect all cells from children
|
||||
all_cells = []
|
||||
for child in contained:
|
||||
all_cells.extend(child.cells)
|
||||
special.cells = self._deduplicate_cells(all_cells)
|
||||
special.cells = self._sort_cells(special.cells)
|
||||
# Conditionally collect cells from children
|
||||
if not self.options.skip_cell_assignment:
|
||||
all_cells = []
|
||||
for child in contained:
|
||||
all_cells.extend(child.cells)
|
||||
special.cells = self._deduplicate_cells(all_cells)
|
||||
special.cells = self._sort_cells(special.cells)
|
||||
else:
|
||||
special.cells = []
|
||||
|
||||
picture_clusters = [
|
||||
c for c in special_clusters if c.label == DocItemLabel.PICTURE
|
||||
|
||||
@@ -93,6 +93,7 @@ vlm = [
|
||||
'transformers (>=4.46.0,<5.0.0)',
|
||||
'accelerate (>=1.2.1,<2.0.0)',
|
||||
'mlx-vlm (>=0.3.0,<1.0.0) ; python_version >= "3.10" and sys_platform == "darwin" and platform_machine == "arm64"',
|
||||
'vllm (>=0.10.0,<1.0.0) ; python_version >= "3.10" and sys_platform == "linux"',
|
||||
]
|
||||
rapidocr = [
|
||||
'rapidocr-onnxruntime (>=1.4.0,<2.0.0) ; python_version < "3.13"',
|
||||
@@ -252,6 +253,7 @@ module = [
|
||||
"huggingface_hub.*",
|
||||
"transformers.*",
|
||||
"pylatexenc.*",
|
||||
"vllm.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||