Prepare existing code for use with new multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer
2025-08-13 14:02:19 +02:00
parent e2cca931be
commit 126944c7ee
7 changed files with 345 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
import math
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from docling_core.types.doc import (

View File

@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
    keep_empty_clusters: bool = (
        False  # Whether to keep clusters that contain no text cells
    )
    skip_cell_assignment: bool = (
        False  # Skip cell-to-cluster assignment for VLM-only processing
    )
    model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
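
Note: a minimal configuration sketch of the new flag (not part of this diff; the import path is an assumption based on where LayoutOptions is defined):

# Hedged sketch: assumes LayoutOptions is importable from docling.datamodel.pipeline_options.
from docling.datamodel.pipeline_options import LayoutOptions

# For a VLM-only pipeline: keep layout detection, but do not attach parsed text cells to clusters.
vlm_layout_options = LayoutOptions(skip_cell_assignment=True)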

View File

@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Generic, Optional, Protocol, Type
from typing import Generic, Optional, Protocol, Type, Union
import numpy as np
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from PIL.Image import Image
from typing_extensions import TypeVar
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
from docling.datamodel.base_models import (
    ItemAndImageEnrichmentElement,
    Page,
    VlmPrediction,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
from docling.datamodel.settings import settings
@@ -26,6 +32,46 @@ class BasePageModel(ABC):
        pass


class BaseVlmModel(ABC):
    """Base class for Vision-Language Models that adds image processing capability."""

    @abstractmethod
    def process_images(
        self, image_batch: Iterable[Union[Image, np.ndarray]]
    ) -> Iterable[VlmPrediction]:
        """Process raw images without page metadata."""


class BaseVlmPageModel(BasePageModel, BaseVlmModel):
    """Base implementation for VLM models that inherit from BasePageModel.

    Provides a default __call__ implementation that extracts images from pages,
    processes them using process_images, and attaches results back to pages.
    """

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Extract images from pages, process them, and attach results back."""

    @abstractmethod
    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Optional[str] = None,
    ) -> Iterable[VlmPrediction]:
        """Process raw images without page metadata.

        Args:
            image_batch: Iterable of PIL Images or numpy arrays
            prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
                If vlm_options.prompt is callable and no prompt is provided, raises ValueError.

        Raises:
            ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
        """


EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
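
Note: a rough sketch of the intended contract (hypothetical subclass, not part of this commit): a concrete backend only implements process_images, while the inherited __call__ takes care of page iteration and result attachment.

# Hedged sketch: class name and canned output are illustrative only.
class DummyVlmModel(BaseVlmPageModel):
    def __init__(self, vlm_options):
        self.vlm_options = vlm_options

    def process_images(self, image_batch, prompt=None):
        # Prompt priority mirrors the docstring above: parameter > vlm_options.prompt > error.
        if prompt is None:
            if callable(self.vlm_options.prompt):
                raise ValueError("Provide a prompt when vlm_options.prompt is callable.")
            prompt = self.vlm_options.prompt
        for _image in image_batch:
            yield VlmPrediction(text="<doctag></doctag>", generation_time=0.0)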

View File

@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
class PagePreprocessingOptions(BaseModel):
    images_scale: Optional[float]
    skip_cell_extraction: bool = (
        False  # Skip text cell extraction for VLM-only processing
    )


class PagePreprocessingModel(BasePageModel):
@@ -41,6 +44,7 @@ class PagePreprocessingModel(BasePageModel):
            else:
                with TimeRecorder(conv_res, "page_parse"):
                    page = self._populate_page_images(page)
                    if not self.options.skip_cell_extraction:
                        page = self._parse_page_cells(conv_res, page)
                yield page
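
Note: a short usage sketch of the new option (field names come from this diff; how the options object is handed to the model is an assumption):

# Hedged sketch: render page images but skip parsing programmatic text cells for a VLM-only run.
preprocessing_options = PagePreprocessingOptions(
    images_scale=2.0,
    skip_cell_extraction=True,
)
preprocessing_model = PagePreprocessingModel(options=preprocessing_options)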

View File

@@ -3,7 +3,10 @@ import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType,
TransformersPromptStyle,
)
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
@@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
skip_special_tokens=True,
)[0]
num_tokens = len(generated_ids[0])
@@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
        raise RuntimeError(
            f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
        )
    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Optional[str] = None,
    ) -> Iterable[VlmPrediction]:
        """Process raw images without page metadata in a single batched inference call."""
        pil_images: list[Image] = []

        for img in image_batch:
            # Convert numpy array to PIL Image if needed
            if isinstance(img, np.ndarray):
                if img.ndim == 3 and img.shape[2] in [3, 4]:
                    from PIL import Image as PILImage

                    pil_img = PILImage.fromarray(img.astype(np.uint8))
                elif img.ndim == 2:
                    from PIL import Image as PILImage

                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                else:
                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
            else:
                pil_img = img

            # Ensure image is in RGB mode (handles RGBA, L, etc.)
            if pil_img.mode != "RGB":
                pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

        if len(pil_images) == 0:
            return

        # Handle prompt with priority: parameter > vlm_options.prompt > error
        if prompt is not None:
            user_prompt = prompt
        elif not callable(self.vlm_options.prompt):
            user_prompt = self.vlm_options.prompt
        else:
            raise ValueError(
                "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
                "Please provide a prompt parameter when calling process_images directly."
            )

        formatted_prompt = self.formulate_prompt(user_prompt)
        prompts: list[str] = [formatted_prompt] * len(pil_images)

        inputs = self.processor(
            text=prompts, images=pil_images, return_tensors="pt", padding=True
        ).to(self.device)

        start_time = time.time()
        generated_ids = self.vlm_model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            use_cache=self.use_cache,
            temperature=self.temperature,
            generation_config=self.generation_config,
            **self.vlm_options.extra_generation_config,
        )
        generation_time = time.time() - start_time

        # Determine per-sample prompt lengths
        try:
            attention_mask = inputs["attention_mask"]
            input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
        except KeyError:
            tokenizer = (
                self.processor.tokenizer
            )  # Expect tokenizer to be present when text is provided
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is not None:
                input_lengths = (
                    (inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
                )
            else:
                # Fallback: assume uniform prompt length (least accurate but preserves execution)
                uniform_len = int(inputs["input_ids"].shape[1])
                input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])

        trimmed_sequences: list[list[int]] = [
            generated_ids[i, int(input_lengths[i]) :].tolist()
            for i in range(generated_ids.shape[0])
        ]
        decoded_texts: list[str] = self.processor.batch_decode(
            trimmed_sequences, skip_special_tokens=True
        )

        # Logging tokens count for the first sample as a representative metric
        if generated_ids.shape[0] > 0:
            num_tokens = int(generated_ids[0].shape[0])
            _log.debug(
                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
            )

        for text in decoded_texts:
            yield VlmPrediction(text=text, generation_time=generation_time)
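
Note: a hedged usage sketch of the batched API (the vlm_model instance and prompt text are illustrative; constructing HuggingFaceTransformersVlmModel is unchanged by this commit):

import numpy as np

# One synthetic RGB "page" as a numpy array; real callers pass rendered page images or PIL Images.
page_images = [np.zeros((1024, 768, 3), dtype=np.uint8)]

# When vlm_options.prompt is a callable, an explicit prompt string must be supplied.
for prediction in vlm_model.process_images(page_images, prompt="Convert this page to DocTags."):
    print(f"{prediction.generation_time:.2f}s -> {prediction.text[:80]}")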

View File

@@ -1,8 +1,12 @@
import logging
import threading
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
# Global lock for MLX model calls - MLX models are not thread-safe
# All MLX models share this lock to prevent concurrent MLX operations
_MLX_GLOBAL_LOCK = threading.Lock()
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@@ -92,6 +100,11 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
self.processor, self.config, user_prompt, num_images=1
)
# MLX models are not thread-safe - use global lock to serialize access
with _MLX_GLOBAL_LOCK:
_log.debug(
"MLX model: Acquired global lock for __call__ method"
)
start_time = time.time()
_log.debug("start generating ...")
@@ -137,6 +150,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
break
generation_time = time.time() - start_time
_log.debug("MLX model: Released global lock")
page_tags = output
_log.debug(
@@ -149,3 +163,82 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
)
yield page
    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Optional[str] = None,
    ) -> Iterable[VlmPrediction]:
        from mlx_vlm import generate

        # MLX models are not thread-safe - use global lock to serialize access
        with _MLX_GLOBAL_LOCK:
            _log.debug("MLX model: Acquired global lock for thread safety")
            for image in image_batch:
                # Convert numpy array to PIL Image if needed
                if isinstance(image, np.ndarray):
                    if image.ndim == 3 and image.shape[2] in [3, 4]:
                        # RGB or RGBA array
                        from PIL import Image as PILImage

                        image = PILImage.fromarray(image.astype(np.uint8))
                    elif image.ndim == 2:
                        # Grayscale array
                        from PIL import Image as PILImage

                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
                    else:
                        raise ValueError(
                            f"Unsupported numpy array shape: {image.shape}"
                        )

                # Ensure image is in RGB mode (handles RGBA, L, etc.)
                if image.mode != "RGB":
                    image = image.convert("RGB")

                # Handle prompt with priority: parameter > vlm_options.prompt > error
                if prompt is not None:
                    user_prompt = prompt
                elif not callable(self.vlm_options.prompt):
                    user_prompt = self.vlm_options.prompt
                else:
                    raise ValueError(
                        "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
                        "Please provide a prompt parameter when calling process_images directly."
                    )

                # Use the MLX chat template approach like in the __call__ method
                formatted_prompt = self.apply_chat_template(
                    self.processor, self.config, user_prompt, num_images=1
                )

                # Generate text from the image - MLX can accept PIL Images directly despite type annotations
                start_time = time.time()
                generated_result = generate(
                    self.vlm_model,
                    self.processor,
                    formatted_prompt,
                    image=image,  # Pass PIL Image directly - much more efficient than disk I/O
                    verbose=False,
                    temp=self.temperature,
                    max_tokens=self.max_tokens,
                )
                generation_time = time.time() - start_time

                # MLX generate returns a tuple (text, info_dict), extract just the text
                if isinstance(generated_result, tuple):
                    generated_text = generated_result[0]
                    _log.debug(
                        f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
                    )
                else:
                    generated_text = generated_result

                _log.debug(f"Generated text in {generation_time:.2f}s.")
                yield VlmPrediction(
                    text=generated_text,
                    generation_time=generation_time,
                    # MLX generate doesn't expose tokens directly, so we leave it empty
                    generated_tokens=[],
                )
            _log.debug("MLX model: Released global lock")

View File

@@ -239,6 +239,9 @@ class LayoutPostprocessor:
        final_clusters = self._sort_clusters(
            self.regular_clusters + self.special_clusters, mode="id"
        )

        # Conditionally process cells if not skipping cell assignment
        if not self.options.skip_cell_assignment:
            for cluster in final_clusters:
                cluster.cells = self._sort_cells(cluster.cells)

                # Also sort cells in children if any
@@ -264,6 +267,8 @@
            if cluster.label in self.LABEL_REMAPPING:
                cluster.label = self.LABEL_REMAPPING[cluster.label]

        # Conditionally assign cells to clusters
        if not self.options.skip_cell_assignment:
            # Initial cell assignment
            clusters = self._assign_cells_to_clusters(clusters)
@@ -350,12 +355,15 @@
                    b=max(c.bbox.b for c in contained),
                )

                # Collect all cells from children
                # Conditionally collect cells from children
                if not self.options.skip_cell_assignment:
                    all_cells = []
                    for child in contained:
                        all_cells.extend(child.cells)
                    special.cells = self._deduplicate_cells(all_cells)
                    special.cells = self._sort_cells(special.cells)
                else:
                    special.cells = []

        picture_clusters = [
            c for c in special_clusters if c.label == DocItemLabel.PICTURE