Prepare existing code for use with new multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer
2025-08-13 14:02:19 +02:00
parent e2cca931be
commit 126944c7ee
7 changed files with 345 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
 import math
 from collections import defaultdict
 from enum import Enum
-from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import numpy as np
 from docling_core.types.doc import (

View File

@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
     keep_empty_clusters: bool = (
         False  # Whether to keep clusters that contain no text cells
     )
+    skip_cell_assignment: bool = (
+        False  # Skip cell-to-cluster assignment for VLM-only processing
+    )
     model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
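
A minimal sketch of how this flag could be set for a VLM-only run, assuming LayoutOptions remains importable from docling.datamodel.pipeline_options (the pipeline stage that consumes the flag is not part of this diff):

    from docling.datamodel.pipeline_options import LayoutOptions

    # Sketch: keep layout clusters but skip cell-to-cluster bookkeeping; a downstream
    # VLM reads the page image directly. keep_empty_clusters=True is a plausible
    # companion setting, since without assigned cells non-formula clusters would
    # otherwise be dropped by the postprocessor.
    layout_options = LayoutOptions(skip_cell_assignment=True, keep_empty_clusters=True)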

View File

@@ -1,11 +1,17 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Generic, Optional, Protocol, Type
+from typing import Generic, Optional, Protocol, Type, Union

+import numpy as np
 from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
+from PIL.Image import Image
 from typing_extensions import TypeVar

-from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
+from docling.datamodel.base_models import (
+    ItemAndImageEnrichmentElement,
+    Page,
+    VlmPrediction,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
 from docling.datamodel.settings import settings
@@ -26,6 +32,46 @@ class BasePageModel(ABC):
         pass


+class BaseVlmModel(ABC):
+    """Base class for Vision-Language Models that adds image processing capability."""
+
+    @abstractmethod
+    def process_images(
+        self, image_batch: Iterable[Union[Image, np.ndarray]]
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+
+
+class BaseVlmPageModel(BasePageModel, BaseVlmModel):
+    """Base implementation for VLM models that inherit from BasePageModel.
+
+    Provides a default __call__ implementation that extracts images from pages,
+    processes them using process_images, and attaches results back to pages.
+    """
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        """Extract images from pages, process them, and attach results back."""
+
+    @abstractmethod
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
+                If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
+
+        Raises:
+            ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
+        """
+
+
 EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
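
To illustrate the contract, a toy subclass is sketched below; every name not shown in the diff (the class itself, its behavior) is hypothetical, and a real model would run inference inside process_images:

    from collections.abc import Iterable
    from typing import Optional, Union

    import numpy as np
    from PIL.Image import Image

    from docling.datamodel.base_models import VlmPrediction
    from docling.models.base_model import BaseVlmPageModel


    class EchoVlmModel(BaseVlmPageModel):  # hypothetical illustration only
        def process_images(
            self,
            image_batch: Iterable[Union[Image, np.ndarray]],
            prompt: Optional[str] = None,
        ) -> Iterable[VlmPrediction]:
            if prompt is None:
                # Mirrors the documented rule: a callable vlm_options.prompt cannot be
                # resolved here, so the caller must pass an explicit prompt string.
                raise ValueError("An explicit prompt is required by this toy model.")
            for image in image_batch:
                # A real model would run VLM inference here; we just echo the prompt.
                yield VlmPrediction(text=f"echo: {prompt}", generation_time=0.0)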

View File

@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
 class PagePreprocessingOptions(BaseModel):
     images_scale: Optional[float]
+    skip_cell_extraction: bool = (
+        False  # Skip text cell extraction for VLM-only processing
+    )


 class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
             else:
                 with TimeRecorder(conv_res, "page_parse"):
                     page = self._populate_page_images(page)
-                    page = self._parse_page_cells(conv_res, page)
+                    if not self.options.skip_cell_extraction:
+                        page = self._parse_page_cells(conv_res, page)

                 yield page

     # Generate the page image and store it in the page object
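
Correspondingly, a VLM-only configuration would construct the preprocessing options roughly as follows (a sketch, assuming the class stays in docling.models.page_preprocessing_model; how the options reach the model is outside this diff):

    from docling.models.page_preprocessing_model import PagePreprocessingOptions

    preprocessing_options = PagePreprocessingOptions(
        images_scale=2.0,           # scale used when rendering page images
        skip_cell_extraction=True,  # no programmatic text cells; the VLM reads the image
    )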

View File

@@ -3,7 +3,10 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union

+import numpy as np
+from PIL.Image import Image

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersModelType,
     TransformersPromptStyle,
 )
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)


-class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 generation_time = time.time() - start_time
                 generated_texts = self.processor.batch_decode(
                     generated_ids[:, inputs["input_ids"].shape[1] :],
-                    skip_special_tokens=False,
+                    skip_special_tokens=True,
                 )[0]

                 num_tokens = len(generated_ids[0])
@@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             raise RuntimeError(
                 f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
             )
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata in a single batched inference call."""
+        pil_images: list[Image] = []
+        for img in image_batch:
+            # Convert numpy array to PIL Image if needed
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+
+            # Ensure image is in RGB mode (handles RGBA, L, etc.)
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+            pil_images.append(pil_img)
+
+        if len(pil_images) == 0:
+            return
+
+        # Handle prompt with priority: parameter > vlm_options.prompt > error
+        if prompt is not None:
+            user_prompt = prompt
+        elif not callable(self.vlm_options.prompt):
+            user_prompt = self.vlm_options.prompt
+        else:
+            raise ValueError(
+                "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                "Please provide a prompt parameter when calling process_images directly."
+            )
+
+        formatted_prompt = self.formulate_prompt(user_prompt)
+        prompts: list[str] = [formatted_prompt] * len(pil_images)
+
+        inputs = self.processor(
+            text=prompts, images=pil_images, return_tensors="pt", padding=True
+        ).to(self.device)
+
+        start_time = time.time()
+        generated_ids = self.vlm_model.generate(
+            **inputs,
+            max_new_tokens=self.max_new_tokens,
+            use_cache=self.use_cache,
+            temperature=self.temperature,
+            generation_config=self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        )
+        generation_time = time.time() - start_time
+
+        # Determine per-sample prompt lengths
+        try:
+            attention_mask = inputs["attention_mask"]
+            input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
+        except KeyError:
+            tokenizer = (
+                self.processor.tokenizer
+            )  # Expect tokenizer to be present when text is provided
+            pad_token_id = tokenizer.pad_token_id
+            if pad_token_id is not None:
+                input_lengths = (
+                    (inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
+                )
+            else:
+                # Fallback: assume uniform prompt length (least accurate but preserves execution)
+                uniform_len = int(inputs["input_ids"].shape[1])
+                input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
+
+        trimmed_sequences: list[list[int]] = [
+            generated_ids[i, int(input_lengths[i]) :].tolist()
+            for i in range(generated_ids.shape[0])
+        ]
+        decoded_texts: list[str] = self.processor.batch_decode(
+            trimmed_sequences, skip_special_tokens=True
+        )
+
+        # Logging tokens count for the first sample as a representative metric
+        if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+
+        for text in decoded_texts:
+            yield VlmPrediction(text=text, generation_time=generation_time)
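
A usage sketch for this batched entry point (the model object and file names are hypothetical; construction of HuggingFaceTransformersVlmModel is elided). It follows the prompt rule from the base class: when vlm_options.prompt is a callable, the caller must pass an explicit prompt:

    from PIL import Image as PILImage

    images = [PILImage.open(name).convert("RGB") for name in ("page_1.png", "page_2.png")]
    predictions = list(
        vlm_model.process_images(images, prompt="Convert this page to DocTags.")
    )  # vlm_model: an already-initialized HuggingFaceTransformersVlmModel
    for prediction in predictions:
        print(f"{prediction.generation_time:.2f}s", prediction.text[:80])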

View File

@@ -1,8 +1,12 @@
 import logging
+import threading
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

+import numpy as np
+from PIL.Image import Image

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
 from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)

+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
+

-class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -92,51 +100,57 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                         self.processor, self.config, user_prompt, num_images=1
                     )

-                    start_time = time.time()
-                    _log.debug("start generating ...")
-
-                    # Call model to generate:
-                    tokens: list[VlmPredictionToken] = []
-
-                    output = ""
-                    for token in self.stream_generate(
-                        self.vlm_model,
-                        self.processor,
-                        prompt,
-                        [hi_res_image],
-                        max_tokens=self.max_tokens,
-                        verbose=False,
-                        temp=self.temperature,
-                    ):
-                        if len(token.logprobs.shape) == 1:
-                            tokens.append(
-                                VlmPredictionToken(
-                                    text=token.text,
-                                    token=token.token,
-                                    logprob=token.logprobs[token.token],
-                                )
-                            )
-                        elif (
-                            len(token.logprobs.shape) == 2
-                            and token.logprobs.shape[0] == 1
-                        ):
-                            tokens.append(
-                                VlmPredictionToken(
-                                    text=token.text,
-                                    token=token.token,
-                                    logprob=token.logprobs[0, token.token],
-                                )
-                            )
-                        else:
-                            _log.warning(
-                                f"incompatible shape for logprobs: {token.logprobs.shape}"
-                            )
-
-                        output += token.text
-                        if "</doctag>" in token.text:
-                            break
+                    # MLX models are not thread-safe - use global lock to serialize access
+                    with _MLX_GLOBAL_LOCK:
+                        _log.debug(
+                            "MLX model: Acquired global lock for __call__ method"
+                        )
+                        start_time = time.time()
+                        _log.debug("start generating ...")
+
+                        # Call model to generate:
+                        tokens: list[VlmPredictionToken] = []
+
+                        output = ""
+                        for token in self.stream_generate(
+                            self.vlm_model,
+                            self.processor,
+                            prompt,
+                            [hi_res_image],
+                            max_tokens=self.max_tokens,
+                            verbose=False,
+                            temp=self.temperature,
+                        ):
+                            if len(token.logprobs.shape) == 1:
+                                tokens.append(
+                                    VlmPredictionToken(
+                                        text=token.text,
+                                        token=token.token,
+                                        logprob=token.logprobs[token.token],
+                                    )
+                                )
+                            elif (
+                                len(token.logprobs.shape) == 2
+                                and token.logprobs.shape[0] == 1
+                            ):
+                                tokens.append(
+                                    VlmPredictionToken(
+                                        text=token.text,
+                                        token=token.token,
+                                        logprob=token.logprobs[0, token.token],
+                                    )
+                                )
+                            else:
+                                _log.warning(
+                                    f"incompatible shape for logprobs: {token.logprobs.shape}"
+                                )
+
+                            output += token.text
+                            if "</doctag>" in token.text:
+                                break

                     generation_time = time.time() - start_time
+                    _log.debug("MLX model: Released global lock")

                     page_tags = output

                     _log.debug(
@@ -149,3 +163,82 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     )

                 yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        from mlx_vlm import generate
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image in image_batch:
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )
+
+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Handle prompt with priority: parameter > vlm_options.prompt > error
+                if prompt is not None:
+                    user_prompt = prompt
+                elif not callable(self.vlm_options.prompt):
+                    user_prompt = self.vlm_options.prompt
+                else:
+                    raise ValueError(
+                        "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                        "Please provide a prompt parameter when calling process_images directly."
+                    )
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Generate text from the image - MLX can accept PIL Images directly despite type annotations
+                start_time = time.time()
+                generated_result = generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    image=image,  # Pass PIL Image directly - much more efficient than disk I/O
+                    verbose=False,
+                    temp=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                generation_time = time.time() - start_time
+
+                # MLX generate returns a tuple (text, info_dict), extract just the text
+                if isinstance(generated_result, tuple):
+                    generated_text = generated_result[0]
+                    _log.debug(
+                        f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
+                    )
+                else:
+                    generated_text = generated_result
+
+                _log.debug(f"Generated text in {generation_time:.2f}s.")
+                yield VlmPrediction(
+                    text=generated_text,
+                    generation_time=generation_time,
+                    # MLX generate doesn't expose tokens directly, so we leave it empty
+                    generated_tokens=[],
+                )
+
+            _log.debug("MLX model: Released global lock")
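
Because _MLX_GLOBAL_LOCK is module-level, concurrent callers are serialized inside the model itself, so callers need no locking of their own. A sketch of calling it from a thread pool (the helper, the mlx_model object, and the images list are hypothetical):

    from concurrent.futures import ThreadPoolExecutor

    def convert_one(image):
        # mlx_model: an already-initialized HuggingFaceMlxModel
        return list(mlx_model.process_images([image], prompt="Convert this page to DocTags."))

    with ThreadPoolExecutor(max_workers=4) as pool:
        # Threads submit work concurrently, but the MLX calls run one at a time.
        results = [preds[0] for preds in pool.map(convert_one, images)]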

View File

@@ -239,15 +239,18 @@ class LayoutPostprocessor:
         final_clusters = self._sort_clusters(
             self.regular_clusters + self.special_clusters, mode="id"
         )

-        for cluster in final_clusters:
-            cluster.cells = self._sort_cells(cluster.cells)
-            # Also sort cells in children if any
-            for child in cluster.children:
-                child.cells = self._sort_cells(child.cells)
-
-        assert self.page.parsed_page is not None
-        self.page.parsed_page.textline_cells = self.cells
-        self.page.parsed_page.has_lines = len(self.cells) > 0
+        # Conditionally process cells if not skipping cell assignment
+        if not self.options.skip_cell_assignment:
+            for cluster in final_clusters:
+                cluster.cells = self._sort_cells(cluster.cells)
+                # Also sort cells in children if any
+                for child in cluster.children:
+                    child.cells = self._sort_cells(child.cells)
+
+            assert self.page.parsed_page is not None
+            self.page.parsed_page.textline_cells = self.cells
+            self.page.parsed_page.has_lines = len(self.cells) > 0

         return final_clusters, self.cells
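
In VLM-only mode the postprocessor still sorts and returns the clusters, but leaves parsed_page.textline_cells untouched for the VLM stage to fill. A sketch, assuming the current LayoutPostprocessor(page, clusters, options) constructor and a page/clusters pair produced earlier in the pipeline:

    options = LayoutOptions(skip_cell_assignment=True, keep_empty_clusters=True)
    processed_clusters, processed_cells = LayoutPostprocessor(
        page, clusters, options
    ).postprocess()
    # page.parsed_page.textline_cells is left as-is when skip_cell_assignment is set.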
@@ -264,36 +267,38 @@ class LayoutPostprocessor:
             if cluster.label in self.LABEL_REMAPPING:
                 cluster.label = self.LABEL_REMAPPING[cluster.label]

-        # Initial cell assignment
-        clusters = self._assign_cells_to_clusters(clusters)
+        # Conditionally assign cells to clusters
+        if not self.options.skip_cell_assignment:
+            # Initial cell assignment
+            clusters = self._assign_cells_to_clusters(clusters)

         # Remove clusters with no cells (if keep_empty_clusters is False),
         # but always keep clusters with label DocItemLabel.FORMULA
         if not self.options.keep_empty_clusters:
             clusters = [
                 cluster
                 for cluster in clusters
                 if cluster.cells or cluster.label == DocItemLabel.FORMULA
             ]

         # Handle orphaned cells
         unassigned = self._find_unassigned_cells(clusters)
         if unassigned and self.options.create_orphan_clusters:
             next_id = max((c.id for c in self.all_clusters), default=0) + 1
             orphan_clusters = []
             for i, cell in enumerate(unassigned):
                 conf = cell.confidence
                 orphan_clusters.append(
                     Cluster(
                         id=next_id + i,
                         label=DocItemLabel.TEXT,
                         bbox=cell.to_bounding_box(),
                         confidence=conf,
                         cells=[cell],
                     )
                 )
             clusters.extend(orphan_clusters)

         # Iterative refinement
         prev_count = len(clusters) + 1
@@ -350,12 +355,15 @@ class LayoutPostprocessor:
                     b=max(c.bbox.b for c in contained),
                 )

-                # Collect all cells from children
-                all_cells = []
-                for child in contained:
-                    all_cells.extend(child.cells)
-                special.cells = self._deduplicate_cells(all_cells)
-                special.cells = self._sort_cells(special.cells)
+                # Conditionally collect cells from children
+                if not self.options.skip_cell_assignment:
+                    all_cells = []
+                    for child in contained:
+                        all_cells.extend(child.cells)
+                    special.cells = self._deduplicate_cells(all_cells)
+                    special.cells = self._sort_cells(special.cells)
+                else:
+                    special.cells = []

         picture_clusters = [
             c for c in special_clusters if c.label == DocItemLabel.PICTURE