Prepare existing code for use with the new multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer
2025-08-13 14:02:19 +02:00
parent e2cca931be
commit 126944c7ee
7 changed files with 345 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
import math
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from docling_core.types.doc import (

View File

@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
keep_empty_clusters: bool = (
False # Whether to keep clusters that contain no text cells
)
skip_cell_assignment: bool = (
False # Skip cell-to-cluster assignment for VLM-only processing
)
model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
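For a VLM-only run, the new flag can be combined with keep_empty_clusters so that layout clusters survive without any assigned text cells. A minimal sketch, assuming LayoutOptions is imported from docling.datamodel.pipeline_options and all other fields keep their defaults:

from docling.datamodel.pipeline_options import LayoutOptions

# VLM-only processing: keep the predicted layout clusters, but skip
# cell-to-cluster assignment since the VLM reads text from the page image.
layout_options = LayoutOptions(
    skip_cell_assignment=True,
    keep_empty_clusters=True,
)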

View File

@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Generic, Optional, Protocol, Type
from typing import Generic, Optional, Protocol, Type, Union
import numpy as np
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from PIL.Image import Image
from typing_extensions import TypeVar
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
from docling.datamodel.base_models import (
ItemAndImageEnrichmentElement,
Page,
VlmPrediction,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
from docling.datamodel.settings import settings
@@ -26,6 +32,46 @@ class BasePageModel(ABC):
pass
class BaseVlmModel(ABC):
"""Base class for Vision-Language Models that adds image processing capability."""
@abstractmethod
def process_images(
self, image_batch: Iterable[Union[Image, np.ndarray]]
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata."""
class BaseVlmPageModel(BasePageModel, BaseVlmModel):
"""Base implementation for VLM models that inherit from BasePageModel.
Provides a default __call__ implementation that extracts images from pages,
processes them using process_images, and attaches results back to pages.
"""
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
"""Extract images from pages, process them, and attach results back."""
@abstractmethod
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Optional[str] = None,
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
Raises:
ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
"""
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
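The prompt-resolution order described in the docstring (explicit parameter first, then a plain-string vlm_options.prompt, otherwise an error) is the part all implementations share. A standalone sketch of that logic; the helper name is hypothetical and only mirrors what subclasses do internally:

from typing import Callable, Optional, Union

def resolve_prompt(
    prompt: Optional[str],
    options_prompt: Union[str, Callable[..., str]],
) -> str:
    # An explicit prompt parameter wins; a static string from vlm_options.prompt
    # is the fallback; a callable prompt needs page context and cannot be used here.
    if prompt is not None:
        return prompt
    if not callable(options_prompt):
        return options_prompt
    raise ValueError(
        "vlm_options.prompt is callable; pass an explicit prompt to process_images."
    )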

View File

@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
class PagePreprocessingOptions(BaseModel):
images_scale: Optional[float]
skip_cell_extraction: bool = (
False # Skip text cell extraction for VLM-only processing
)
class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
else:
with TimeRecorder(conv_res, "page_parse"):
page = self._populate_page_images(page)
page = self._parse_page_cells(conv_res, page)
if not self.options.skip_cell_extraction:
page = self._parse_page_cells(conv_res, page)
yield page
# Generate the page image and store it in the page object
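A sketch of configuring the preprocessing stage for a VLM-only pipeline, assuming PagePreprocessingOptions is importable from docling.models.page_preprocessing_model and has no other required fields:

from docling.models.page_preprocessing_model import PagePreprocessingOptions

# Render page images for the VLM, but skip PDF text-cell extraction.
preprocess_options = PagePreprocessingOptions(
    images_scale=2.0,           # higher-resolution page images for the VLM
    skip_cell_extraction=True,  # text comes from the VLM, not the PDF backend
)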

View File

@@ -3,7 +3,10 @@ import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType,
TransformersPromptStyle,
)
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
skip_special_tokens=True,
)[0]
num_tokens = len(generated_ids[0])
@@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
raise RuntimeError(
f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
)
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Optional[str] = None,
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata in a single batched inference call."""
pil_images: list[Image] = []
for img in image_batch:
# Convert numpy array to PIL Image if needed
if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in [3, 4]:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8))
elif img.ndim == 2:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
else:
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
else:
pil_img = img
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if pil_img.mode != "RGB":
pil_img = pil_img.convert("RGB")
pil_images.append(pil_img)
if len(pil_images) == 0:
return
# Handle prompt with priority: parameter > vlm_options.prompt > error
if prompt is not None:
user_prompt = prompt
elif not callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt
else:
raise ValueError(
"vlm_options.prompt is callable but no prompt parameter provided to process_images. "
"Please provide a prompt parameter when calling process_images directly."
)
formatted_prompt = self.formulate_prompt(user_prompt)
prompts: list[str] = [formatted_prompt] * len(pil_images)
inputs = self.processor(
text=prompts, images=pil_images, return_tensors="pt", padding=True
).to(self.device)
start_time = time.time()
generated_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
generation_time = time.time() - start_time
# Determine per-sample prompt lengths
try:
attention_mask = inputs["attention_mask"]
input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
except KeyError:
tokenizer = (
self.processor.tokenizer
) # Expect tokenizer to be present when text is provided
pad_token_id = tokenizer.pad_token_id
if pad_token_id is not None:
input_lengths = (
(inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
)
else:
# Fallback: assume uniform prompt length (least accurate but preserves execution)
uniform_len = int(inputs["input_ids"].shape[1])
input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
trimmed_sequences: list[list[int]] = [
generated_ids[i, int(input_lengths[i]) :].tolist()
for i in range(generated_ids.shape[0])
]
decoded_texts: list[str] = self.processor.batch_decode(
trimmed_sequences, skip_special_tokens=True
)
# Log the token count of the first sample as a representative metric
if generated_ids.shape[0] > 0:
num_tokens = int(generated_ids[0].shape[0])
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
for text in decoded_texts:
yield VlmPrediction(text=text, generation_time=generation_time)
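A usage sketch for the batched path, assuming model is an already-initialized HuggingFaceTransformersVlmModel and the page images exist on disk; the method accepts PIL images and numpy arrays interchangeably and yields one VlmPrediction per input:

import numpy as np
from PIL import Image

pages = [
    Image.open("page_1.png"),                  # PIL image (hypothetical file)
    np.zeros((1024, 768, 3), dtype=np.uint8),  # RGB numpy array
]

# `model` is assumed to be a configured HuggingFaceTransformersVlmModel.
for prediction in model.process_images(pages, prompt="Convert this page to DocTags."):
    print(f"{prediction.generation_time:.2f}s", prediction.text[:80])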

View File

@@ -1,8 +1,12 @@
import logging
import threading
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
# Global lock for MLX model calls - MLX models are not thread-safe
# All MLX models share this lock to prevent concurrent MLX operations
_MLX_GLOBAL_LOCK = threading.Lock()
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@@ -92,51 +100,57 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
self.processor, self.config, user_prompt, num_images=1
)
start_time = time.time()
_log.debug("start generating ...")
# MLX models are not thread-safe - use global lock to serialize access
with _MLX_GLOBAL_LOCK:
_log.debug(
"MLX model: Acquired global lock for __call__ method"
)
start_time = time.time()
_log.debug("start generating ...")
# Call model to generate:
tokens: list[VlmPredictionToken] = []
# Call model to generate:
tokens: list[VlmPredictionToken] = []
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
if "</doctag>" in token.text:
break
output += token.text
if "</doctag>" in token.text:
break
generation_time = time.time() - start_time
generation_time = time.time() - start_time
_log.debug("MLX model: Released global lock")
page_tags = output
_log.debug(
@@ -149,3 +163,82 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
)
yield page
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Optional[str] = None,
) -> Iterable[VlmPrediction]:
from mlx_vlm import generate
# MLX models are not thread-safe - use global lock to serialize access
with _MLX_GLOBAL_LOCK:
_log.debug("MLX model: Acquired global lock for thread safety")
for image in image_batch:
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
if image.ndim == 3 and image.shape[2] in [3, 4]:
# RGB or RGBA array
from PIL import Image as PILImage
image = PILImage.fromarray(image.astype(np.uint8))
elif image.ndim == 2:
# Grayscale array
from PIL import Image as PILImage
image = PILImage.fromarray(image.astype(np.uint8), mode="L")
else:
raise ValueError(
f"Unsupported numpy array shape: {image.shape}"
)
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if image.mode != "RGB":
image = image.convert("RGB")
# Handle prompt with priority: parameter > vlm_options.prompt > error
if prompt is not None:
user_prompt = prompt
elif not callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt
else:
raise ValueError(
"vlm_options.prompt is callable but no prompt parameter provided to process_images. "
"Please provide a prompt parameter when calling process_images directly."
)
# Use the MLX chat template approach like in the __call__ method
formatted_prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
)
# Generate text from the image - MLX can accept PIL Images directly despite type annotations
start_time = time.time()
generated_result = generate(
self.vlm_model,
self.processor,
formatted_prompt,
image=image, # Pass PIL Image directly - much more efficient than disk I/O
verbose=False,
temp=self.temperature,
max_tokens=self.max_tokens,
)
generation_time = time.time() - start_time
# MLX generate returns a tuple (text, info_dict), extract just the text
if isinstance(generated_result, tuple):
generated_text = generated_result[0]
_log.debug(
f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
)
else:
generated_text = generated_result
_log.debug(f"Generated text in {generation_time:.2f}s.")
yield VlmPrediction(
text=generated_text,
generation_time=generation_time,
# MLX generate doesn't expose tokens directly, so we leave it empty
generated_tokens=[],
)
_log.debug("MLX model: Released global lock")

View File

@@ -239,15 +239,18 @@ class LayoutPostprocessor:
final_clusters = self._sort_clusters(
self.regular_clusters + self.special_clusters, mode="id"
)
for cluster in final_clusters:
cluster.cells = self._sort_cells(cluster.cells)
# Also sort cells in children if any
for child in cluster.children:
child.cells = self._sort_cells(child.cells)
assert self.page.parsed_page is not None
self.page.parsed_page.textline_cells = self.cells
self.page.parsed_page.has_lines = len(self.cells) > 0
# Conditionally process cells if not skipping cell assignment
if not self.options.skip_cell_assignment:
for cluster in final_clusters:
cluster.cells = self._sort_cells(cluster.cells)
# Also sort cells in children if any
for child in cluster.children:
child.cells = self._sort_cells(child.cells)
assert self.page.parsed_page is not None
self.page.parsed_page.textline_cells = self.cells
self.page.parsed_page.has_lines = len(self.cells) > 0
return final_clusters, self.cells
@@ -264,36 +267,38 @@ class LayoutPostprocessor:
if cluster.label in self.LABEL_REMAPPING:
cluster.label = self.LABEL_REMAPPING[cluster.label]
# Initial cell assignment
clusters = self._assign_cells_to_clusters(clusters)
# Conditionally assign cells to clusters
if not self.options.skip_cell_assignment:
# Initial cell assignment
clusters = self._assign_cells_to_clusters(clusters)
# Remove clusters with no cells (if keep_empty_clusters is False),
# but always keep clusters with label DocItemLabel.FORMULA
if not self.options.keep_empty_clusters:
clusters = [
cluster
for cluster in clusters
if cluster.cells or cluster.label == DocItemLabel.FORMULA
]
# Remove clusters with no cells (if keep_empty_clusters is False),
# but always keep clusters with label DocItemLabel.FORMULA
if not self.options.keep_empty_clusters:
clusters = [
cluster
for cluster in clusters
if cluster.cells or cluster.label == DocItemLabel.FORMULA
]
# Handle orphaned cells
unassigned = self._find_unassigned_cells(clusters)
if unassigned and self.options.create_orphan_clusters:
next_id = max((c.id for c in self.all_clusters), default=0) + 1
orphan_clusters = []
for i, cell in enumerate(unassigned):
conf = cell.confidence
# Handle orphaned cells
unassigned = self._find_unassigned_cells(clusters)
if unassigned and self.options.create_orphan_clusters:
next_id = max((c.id for c in self.all_clusters), default=0) + 1
orphan_clusters = []
for i, cell in enumerate(unassigned):
conf = cell.confidence
orphan_clusters.append(
Cluster(
id=next_id + i,
label=DocItemLabel.TEXT,
bbox=cell.to_bounding_box(),
confidence=conf,
cells=[cell],
orphan_clusters.append(
Cluster(
id=next_id + i,
label=DocItemLabel.TEXT,
bbox=cell.to_bounding_box(),
confidence=conf,
cells=[cell],
)
)
)
clusters.extend(orphan_clusters)
clusters.extend(orphan_clusters)
# Iterative refinement
prev_count = len(clusters) + 1
@@ -350,12 +355,15 @@ class LayoutPostprocessor:
b=max(c.bbox.b for c in contained),
)
# Collect all cells from children
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
# Conditionally collect cells from children
if not self.options.skip_cell_assignment:
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
else:
special.cells = []
picture_clusters = [
c for c in special_clusters if c.label == DocItemLabel.PICTURE