Prepare existing code for use with new multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer
2025-08-13 14:02:19 +02:00
parent e2cca931be
commit 126944c7ee
7 changed files with 345 additions and 90 deletions


@@ -1,7 +1,7 @@
import math
from collections import defaultdict
from enum import Enum
-from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
from docling_core.types.doc import (


@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
    keep_empty_clusters: bool = (
        False  # Whether to keep clusters that contain no text cells
    )
+    skip_cell_assignment: bool = (
+        False  # Skip cell-to-cluster assignment for VLM-only processing
+    )
    model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
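
A minimal configuration sketch (illustrative, not part of this commit; the import path is assumed from where `LayoutOptions` normally lives):

```python
# Illustrative only: enable the VLM-only layout path by skipping
# cell-to-cluster assignment. Import path assumed, not shown in this diff.
from docling.datamodel.pipeline_options import LayoutOptions

layout_options = LayoutOptions(
    skip_cell_assignment=True,  # new flag introduced in this commit
)
print(layout_options.keep_empty_clusters)  # existing default: False
```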


@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
-from typing import Generic, Optional, Protocol, Type
+from typing import Generic, Optional, Protocol, Type, Union

+import numpy as np
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
+from PIL.Image import Image
from typing_extensions import TypeVar

-from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
+from docling.datamodel.base_models import (
+    ItemAndImageEnrichmentElement,
+    Page,
+    VlmPrediction,
+)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
from docling.datamodel.settings import settings
@@ -26,6 +32,46 @@ class BasePageModel(ABC):
        pass


+class BaseVlmModel(ABC):
+    """Base class for Vision-Language Models that adds image processing capability."""
+
+    @abstractmethod
+    def process_images(
+        self, image_batch: Iterable[Union[Image, np.ndarray]]
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+
+
+class BaseVlmPageModel(BasePageModel, BaseVlmModel):
+    """Base implementation for VLM models that inherit from BasePageModel.
+
+    Provides a default __call__ implementation that extracts images from pages,
+    processes them using process_images, and attaches results back to pages.
+    """
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        """Extract images from pages, process them, and attach results back."""
+
+    @abstractmethod
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
+                If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
+
+        Raises:
+            ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
+        """


EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
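
A minimal sketch of how a custom model could plug into the new interface (illustrative, not part of this commit; `EchoVlmModel` is hypothetical):

```python
# Hypothetical subclass: implements the new process_images() contract and
# inherits the default page-level __call__ from BaseVlmPageModel.
from collections.abc import Iterable
from typing import Optional, Union

import numpy as np
from PIL.Image import Image

from docling.datamodel.base_models import VlmPrediction
from docling.models.base_model import BaseVlmPageModel


class EchoVlmModel(BaseVlmPageModel):
    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Optional[str] = None,
    ) -> Iterable[VlmPrediction]:
        # Toy behavior: ignore the pixels and echo the prompt back.
        for _img in image_batch:
            yield VlmPrediction(text=prompt or "", generation_time=0.0)


model = EchoVlmModel()
for pred in model.process_images([np.zeros((32, 32, 3), dtype=np.uint8)], prompt="hello"):
    print(pred.text)
```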


@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
class PagePreprocessingOptions(BaseModel):
    images_scale: Optional[float]
+    skip_cell_extraction: bool = (
+        False  # Skip text cell extraction for VLM-only processing
+    )


class PagePreprocessingModel(BasePageModel):
@@ -41,6 +44,7 @@ class PagePreprocessingModel(BasePageModel):
            else:
                with TimeRecorder(conv_res, "page_parse"):
                    page = self._populate_page_images(page)
-                    page = self._parse_page_cells(conv_res, page)
+                    if not self.options.skip_cell_extraction:
+                        page = self._parse_page_cells(conv_res, page)

                yield page
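
Illustrative configuration of the new flag (not part of this commit; the import path and constructor shape are assumed from this file's context):

```python
# Illustrative only: render page images but skip text cell extraction,
# e.g. when a downstream VLM stage consumes the images directly.
from docling.models.page_preprocessing_model import (
    PagePreprocessingModel,
    PagePreprocessingOptions,
)

options = PagePreprocessingOptions(images_scale=2.0, skip_cell_extraction=True)
preprocess_model = PagePreprocessingModel(options=options)  # constructor shape assumed
```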


@@ -3,7 +3,10 @@ import logging
import time
from collections.abc import Iterable
from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
+
+import numpy as np
+from PIL.Image import Image

from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
@@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
    TransformersModelType,
    TransformersPromptStyle,
)
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
@@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)


-class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
        enabled: bool,
@@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                    generation_time = time.time() - start_time
                    generated_texts = self.processor.batch_decode(
                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
+                        skip_special_tokens=True,
                    )[0]

                    num_tokens = len(generated_ids[0])
@@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
        raise RuntimeError(
            f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
        )
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata in a single batched inference call."""
+        pil_images: list[Image] = []
+
+        for img in image_batch:
+            # Convert numpy array to PIL Image if needed
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+
+            # Ensure image is in RGB mode (handles RGBA, L, etc.)
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+
+            pil_images.append(pil_img)
+
+        if len(pil_images) == 0:
+            return
+
+        # Handle prompt with priority: parameter > vlm_options.prompt > error
+        if prompt is not None:
+            user_prompt = prompt
+        elif not callable(self.vlm_options.prompt):
+            user_prompt = self.vlm_options.prompt
+        else:
+            raise ValueError(
+                "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                "Please provide a prompt parameter when calling process_images directly."
+            )
+
+        formatted_prompt = self.formulate_prompt(user_prompt)
+        prompts: list[str] = [formatted_prompt] * len(pil_images)
+
+        inputs = self.processor(
+            text=prompts, images=pil_images, return_tensors="pt", padding=True
+        ).to(self.device)
+
+        start_time = time.time()
+        generated_ids = self.vlm_model.generate(
+            **inputs,
+            max_new_tokens=self.max_new_tokens,
+            use_cache=self.use_cache,
+            temperature=self.temperature,
+            generation_config=self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        )
+        generation_time = time.time() - start_time
+
+        # Determine per-sample prompt lengths
+        try:
+            attention_mask = inputs["attention_mask"]
+            input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
+        except KeyError:
+            tokenizer = (
+                self.processor.tokenizer
+            )  # Expect tokenizer to be present when text is provided
+            pad_token_id = tokenizer.pad_token_id
+            if pad_token_id is not None:
+                input_lengths = (
+                    (inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
+                )
+            else:
+                # Fallback: assume uniform prompt length (least accurate but preserves execution)
+                uniform_len = int(inputs["input_ids"].shape[1])
+                input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
+
+        trimmed_sequences: list[list[int]] = [
+            generated_ids[i, int(input_lengths[i]) :].tolist()
+            for i in range(generated_ids.shape[0])
+        ]
+        decoded_texts: list[str] = self.processor.batch_decode(
+            trimmed_sequences, skip_special_tokens=True
+        )
+
+        # Logging tokens count for the first sample as a representative metric
+        if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+
+        for text in decoded_texts:
+            yield VlmPrediction(text=text, generation_time=generation_time)
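
Illustrative caller for the new batched API (not part of this commit); it assumes `vlm_model` is an already-initialized `HuggingFaceTransformersVlmModel`:

```python
# Illustrative only: batch two page renders through process_images().
# An explicit prompt is passed because vlm_options.prompt may be a callable,
# which process_images() rejects without an override.
from PIL import Image

pages = [Image.open("page_1.png"), Image.open("page_2.png")]  # hypothetical files

for pred in vlm_model.process_images(pages, prompt="Convert this page to markdown."):
    print(f"{pred.generation_time:.2f}s  {pred.text[:120]}")
```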


@@ -1,8 +1,12 @@
import logging
+import threading
import time
from collections.abc import Iterable
from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
+
+import numpy as np
+from PIL.Image import Image

from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)

+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
+

-class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
        enabled: bool,
@@ -92,6 +100,11 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                        self.processor, self.config, user_prompt, num_images=1
                    )

+                    # MLX models are not thread-safe - use global lock to serialize access
+                    with _MLX_GLOBAL_LOCK:
+                        _log.debug(
+                            "MLX model: Acquired global lock for __call__ method"
+                        )
                    start_time = time.time()
                    _log.debug("start generating ...")
@@ -137,6 +150,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                            break
                    generation_time = time.time() - start_time
+                    _log.debug("MLX model: Released global lock")
                    page_tags = output

                    _log.debug(
@@ -149,3 +163,82 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                    )

            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        from mlx_vlm import generate
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image in image_batch:
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )
+
+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Handle prompt with priority: parameter > vlm_options.prompt > error
+                if prompt is not None:
+                    user_prompt = prompt
+                elif not callable(self.vlm_options.prompt):
+                    user_prompt = self.vlm_options.prompt
+                else:
+                    raise ValueError(
+                        "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                        "Please provide a prompt parameter when calling process_images directly."
+                    )
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Generate text from the image - MLX can accept PIL Images directly despite type annotations
+                start_time = time.time()
+                generated_result = generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    image=image,  # Pass PIL Image directly - much more efficient than disk I/O
+                    verbose=False,
+                    temp=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                generation_time = time.time() - start_time
+
+                # MLX generate returns a tuple (text, info_dict), extract just the text
+                if isinstance(generated_result, tuple):
+                    generated_text = generated_result[0]
+                    _log.debug(
+                        f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
+                    )
+                else:
+                    generated_text = generated_result
+
+                _log.debug(f"Generated text in {generation_time:.2f}s.")
+                yield VlmPrediction(
+                    text=generated_text,
+                    generation_time=generation_time,
+                    # MLX generate doesn't expose tokens directly, so we leave it empty
+                    generated_tokens=[],
+                )
+
+        _log.debug("MLX model: Released global lock")


@@ -239,6 +239,9 @@ class LayoutPostprocessor:
        final_clusters = self._sort_clusters(
            self.regular_clusters + self.special_clusters, mode="id"
        )

-        for cluster in final_clusters:
-            cluster.cells = self._sort_cells(cluster.cells)
-            # Also sort cells in children if any
+        # Conditionally process cells if not skipping cell assignment
+        if not self.options.skip_cell_assignment:
+            for cluster in final_clusters:
+                cluster.cells = self._sort_cells(cluster.cells)
+                # Also sort cells in children if any
@@ -264,6 +267,8 @@ class LayoutPostprocessor:
            if cluster.label in self.LABEL_REMAPPING:
                cluster.label = self.LABEL_REMAPPING[cluster.label]

-        # Initial cell assignment
-        clusters = self._assign_cells_to_clusters(clusters)
+        # Conditionally assign cells to clusters
+        if not self.options.skip_cell_assignment:
+            # Initial cell assignment
+            clusters = self._assign_cells_to_clusters(clusters)
@@ -350,12 +355,15 @@ class LayoutPostprocessor:
                    b=max(c.bbox.b for c in contained),
                )

-                # Collect all cells from children
-                all_cells = []
-                for child in contained:
-                    all_cells.extend(child.cells)
-                special.cells = self._deduplicate_cells(all_cells)
-                special.cells = self._sort_cells(special.cells)
+                # Conditionally collect cells from children
+                if not self.options.skip_cell_assignment:
+                    all_cells = []
+                    for child in contained:
+                        all_cells.extend(child.cells)
+                    special.cells = self._deduplicate_cells(all_cells)
+                    special.cells = self._sort_cells(special.cells)
+                else:
+                    special.cells = []

        picture_clusters = [
            c for c in special_clusters if c.label == DocItemLabel.PICTURE