diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index c753ac60..fa34a352 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -1,7 +1,7 @@
 import math
 from collections import defaultdict
 from enum import Enum
-from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import numpy as np
 from docling_core.types.doc import (
diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index 36f26fef..9a3dce3f 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
     keep_empty_clusters: bool = (
         False  # Whether to keep clusters that contain no text cells
     )
+    skip_cell_assignment: bool = (
+        False  # Skip cell-to-cluster assignment for VLM-only processing
+    )
     model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
diff --git a/docling/models/base_model.py b/docling/models/base_model.py
index b0a43f40..21a358b3 100644
--- a/docling/models/base_model.py
+++ b/docling/models/base_model.py
@@ -1,11 +1,17 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Generic, Optional, Protocol, Type
+from typing import Generic, Optional, Protocol, Type, Union
 
+import numpy as np
 from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
+from PIL.Image import Image
 from typing_extensions import TypeVar
 
-from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
+from docling.datamodel.base_models import (
+    ItemAndImageEnrichmentElement,
+    Page,
+    VlmPrediction,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
 from docling.datamodel.settings import settings
@@ -26,6 +32,46 @@ class BasePageModel(ABC):
         pass
 
 
+class BaseVlmModel(ABC):
+    """Base class for Vision-Language Models that adds image processing capability."""
+
+    @abstractmethod
+    def process_images(
+        self, image_batch: Iterable[Union[Image, np.ndarray]]
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+
+
+class BaseVlmPageModel(BasePageModel, BaseVlmModel):
+    """Base implementation for VLM models that inherit from BasePageModel.
+
+    Provides a default __call__ implementation that extracts images from pages,
+    processes them using process_images, and attaches results back to pages.
+    """
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        """Extract images from pages, process them, and attach results back."""
+
+    @abstractmethod
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
+                If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
+
+        Raises:
+            ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
+        """
+
+
 EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
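The BaseVlmModel / BaseVlmPageModel split above lets callers run a VLM over raw images without building Page objects first. A minimal usage sketch, assuming model is an already-constructed concrete subclass from this patch (for example HuggingFaceTransformersVlmModel); only process_images and VlmPrediction come from the code above:

    from PIL import Image

    images = [Image.open("page_1.png"), Image.open("page_2.png")]

    # An explicit prompt is required when vlm_options.prompt is a callable,
    # because process_images has no Page object to pass to that callable.
    for pred in model.process_images(images, prompt="Convert this page to DocTags."):
        print(f"{pred.generation_time:.2f}s", pred.text[:80])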
+ """ + + EnrichElementT = TypeVar("EnrichElementT", default=NodeItem) diff --git a/docling/models/page_preprocessing_model.py b/docling/models/page_preprocessing_model.py index c4de3873..ccc43f53 100644 --- a/docling/models/page_preprocessing_model.py +++ b/docling/models/page_preprocessing_model.py @@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder class PagePreprocessingOptions(BaseModel): images_scale: Optional[float] + skip_cell_extraction: bool = ( + False # Skip text cell extraction for VLM-only processing + ) class PagePreprocessingModel(BasePageModel): @@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel): else: with TimeRecorder(conv_res, "page_parse"): page = self._populate_page_images(page) - page = self._parse_page_cells(conv_res, page) + if not self.options.skip_cell_extraction: + page = self._parse_page_cells(conv_res, page) yield page # Generate the page image and store it in the page object diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py index d84925dd..36b819af 100644 --- a/docling/models/vlm_models_inline/hf_transformers_model.py +++ b/docling/models/vlm_models_inline/hf_transformers_model.py @@ -3,7 +3,10 @@ import logging import time from collections.abc import Iterable from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, Union + +import numpy as np +from PIL.Image import Image from docling.datamodel.accelerator_options import ( AcceleratorOptions, @@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import ( TransformersModelType, TransformersPromptStyle, ) -from docling.models.base_model import BasePageModel +from docling.models.base_model import BaseVlmPageModel from docling.models.utils.hf_model_download import ( HuggingFaceModelDownloadMixin, ) @@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder _log = logging.getLogger(__name__) -class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin): +class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): def __init__( self, enabled: bool, @@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix generation_time = time.time() - start_time generated_texts = self.processor.batch_decode( generated_ids[:, inputs["input_ids"].shape[1] :], - skip_special_tokens=False, + skip_special_tokens=True, )[0] num_tokens = len(generated_ids[0]) @@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix raise RuntimeError( f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}." 
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index d84925dd..36b819af 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -3,7 +3,10 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersModelType,
     TransformersPromptStyle,
 )
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder
 
 _log = logging.getLogger(__name__)
 
-class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                     generation_time = time.time() - start_time
                     generated_texts = self.processor.batch_decode(
                         generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
+                        skip_special_tokens=True,
                     )[0]
 
                     num_tokens = len(generated_ids[0])
@@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         raise RuntimeError(
             f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
         )
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata in a single batched inference call."""
+        pil_images: list[Image] = []
+
+        for img in image_batch:
+            # Convert numpy array to PIL Image if needed
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+
+            # Ensure image is in RGB mode (handles RGBA, L, etc.)
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+
+            pil_images.append(pil_img)
+
+        if len(pil_images) == 0:
+            return
+
+        # Handle prompt with priority: parameter > vlm_options.prompt > error
+        if prompt is not None:
+            user_prompt = prompt
+        elif not callable(self.vlm_options.prompt):
+            user_prompt = self.vlm_options.prompt
+        else:
+            raise ValueError(
+                "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                "Please provide a prompt parameter when calling process_images directly."
+            )
+
+        formatted_prompt = self.formulate_prompt(user_prompt)
+        prompts: list[str] = [formatted_prompt] * len(pil_images)
+
+        inputs = self.processor(
+            text=prompts, images=pil_images, return_tensors="pt", padding=True
+        ).to(self.device)
+
+        start_time = time.time()
+        generated_ids = self.vlm_model.generate(
+            **inputs,
+            max_new_tokens=self.max_new_tokens,
+            use_cache=self.use_cache,
+            temperature=self.temperature,
+            generation_config=self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        )
+        generation_time = time.time() - start_time
+
+        # Determine per-sample prompt lengths
+        try:
+            attention_mask = inputs["attention_mask"]
+            input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
+        except KeyError:
+            tokenizer = (
+                self.processor.tokenizer
+            )  # Expect tokenizer to be present when text is provided
+            pad_token_id = tokenizer.pad_token_id
+            if pad_token_id is not None:
+                input_lengths = (
+                    (inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
+                )
+            else:
+                # Fallback: assume uniform prompt length (least accurate but preserves execution)
+                uniform_len = int(inputs["input_ids"].shape[1])
+                input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
+
+        trimmed_sequences: list[list[int]] = [
+            generated_ids[i, int(input_lengths[i]) :].tolist()
+            for i in range(generated_ids.shape[0])
+        ]
+        decoded_texts: list[str] = self.processor.batch_decode(
+            trimmed_sequences, skip_special_tokens=True
+        )
+
+        # Logging tokens count for the first sample as a representative metric
+        if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+
+        for text in decoded_texts:
+            yield VlmPrediction(text=text, generation_time=generation_time)
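Because the batched process_images path pads all prompts to the same length, the per-sample prompt length is recovered from attention_mask (or from the pad token) before decoding, so each VlmPrediction contains only newly generated text. A usage sketch, assuming vlm is a constructed HuggingFaceTransformersVlmModel; the array shapes and prompt string are illustrative:

    import numpy as np

    # process_images converts ndarray inputs (HxWx3, HxWx4 or HxW) to PIL internally.
    frames = [np.zeros((1024, 768, 3), dtype=np.uint8) for _ in range(4)]

    predictions = list(vlm.process_images(frames, prompt="Describe the page layout."))
    assert len(predictions) == 4  # one prediction per input image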
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 4b37fb48..294fe9cc 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -1,8 +1,12 @@
 import logging
+import threading
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
+
+import numpy as np
+from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
 )
 from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
 
 _log = logging.getLogger(__name__)
 
+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
 
-class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
+
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -92,51 +100,57 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 self.processor, self.config, user_prompt, num_images=1
             )
 
-            start_time = time.time()
-            _log.debug("start generating ...")
-
-            # Call model to generate:
-            tokens: list[VlmPredictionToken] = []
-
-            output = ""
-            for token in self.stream_generate(
-                self.vlm_model,
-                self.processor,
-                prompt,
-                [hi_res_image],
-                max_tokens=self.max_tokens,
-                verbose=False,
-                temp=self.temperature,
-            ):
-                if len(token.logprobs.shape) == 1:
-                    tokens.append(
-                        VlmPredictionToken(
-                            text=token.text,
-                            token=token.token,
-                            logprob=token.logprobs[token.token],
-                        )
-                    )
-                elif (
-                    len(token.logprobs.shape) == 2
-                    and token.logprobs.shape[0] == 1
-                ):
-                    tokens.append(
-                        VlmPredictionToken(
-                            text=token.text,
-                            token=token.token,
-                            logprob=token.logprobs[0, token.token],
-                        )
-                    )
-                else:
-                    _log.warning(
-                        f"incompatible shape for logprobs: {token.logprobs.shape}"
-                    )
-
-                output += token.text
-                if "</doctag>" in token.text:
-                    break
-
-            generation_time = time.time() - start_time
+            # MLX models are not thread-safe - use global lock to serialize access
+            with _MLX_GLOBAL_LOCK:
+                _log.debug(
+                    "MLX model: Acquired global lock for __call__ method"
+                )
+                start_time = time.time()
+                _log.debug("start generating ...")
+
+                # Call model to generate:
+                tokens: list[VlmPredictionToken] = []
+
+                output = ""
+                for token in self.stream_generate(
+                    self.vlm_model,
+                    self.processor,
+                    prompt,
+                    [hi_res_image],
+                    max_tokens=self.max_tokens,
+                    verbose=False,
+                    temp=self.temperature,
+                ):
+                    if len(token.logprobs.shape) == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[token.token],
+                            )
+                        )
+                    elif (
+                        len(token.logprobs.shape) == 2
+                        and token.logprobs.shape[0] == 1
+                    ):
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[0, token.token],
+                            )
+                        )
+                    else:
+                        _log.warning(
+                            f"incompatible shape for logprobs: {token.logprobs.shape}"
+                        )
+
+                    output += token.text
+                    if "</doctag>" in token.text:
+                        break
+
+                generation_time = time.time() - start_time
+                _log.debug("MLX model: Released global lock")
 
             page_tags = output
 
             _log.debug(
@@ -149,3 +163,82 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
             )
 
         yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        from mlx_vlm import generate
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image in image_batch:
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )
+
+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Handle prompt with priority: parameter > vlm_options.prompt > error
+                if prompt is not None:
+                    user_prompt = prompt
+                elif not callable(self.vlm_options.prompt):
+                    user_prompt = self.vlm_options.prompt
+                else:
+                    raise ValueError(
+                        "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                        "Please provide a prompt parameter when calling process_images directly."
+                    )
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Generate text from the image - MLX can accept PIL Images directly despite type annotations
+                start_time = time.time()
+                generated_result = generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    image=image,  # Pass PIL Image directly - much more efficient than disk I/O
+                    verbose=False,
+                    temp=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                generation_time = time.time() - start_time
+
+                # MLX generate returns a tuple (text, info_dict), extract just the text
+                if isinstance(generated_result, tuple):
+                    generated_text = generated_result[0]
+                    _log.debug(
+                        f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
+                    )
+                else:
+                    generated_text = generated_result
+
+                _log.debug(f"Generated text in {generation_time:.2f}s.")
+                yield VlmPrediction(
+                    text=generated_text,
+                    generation_time=generation_time,
+                    # MLX generate doesn't expose tokens directly, so we leave it empty
+                    generated_tokens=[],
+                )
+            _log.debug("MLX model: Released global lock")
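MLX inference is serialized through the module-level _MLX_GLOBAL_LOCK, so several worker threads (or several model instances in one process) can call into MLX without overlapping generations; throughput stays bounded by one generation at a time. A sketch of the calling pattern the lock protects, assuming mlx_model and images are set up elsewhere; the thread-pool wiring is illustrative:

    from concurrent.futures import ThreadPoolExecutor

    def run_one(image):
        # Each call acquires _MLX_GLOBAL_LOCK internally, so overlapping
        # submissions queue up instead of running MLX code concurrently.
        return list(mlx_model.process_images([image], prompt="Convert this page to DocTags."))

    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(run_one, images))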
diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py
index edc6b396..0ac31954 100644
--- a/docling/utils/layout_postprocessor.py
+++ b/docling/utils/layout_postprocessor.py
@@ -239,15 +239,18 @@ class LayoutPostprocessor:
         final_clusters = self._sort_clusters(
             self.regular_clusters + self.special_clusters, mode="id"
         )
-        for cluster in final_clusters:
-            cluster.cells = self._sort_cells(cluster.cells)
-            # Also sort cells in children if any
-            for child in cluster.children:
-                child.cells = self._sort_cells(child.cells)
 
-        assert self.page.parsed_page is not None
-        self.page.parsed_page.textline_cells = self.cells
-        self.page.parsed_page.has_lines = len(self.cells) > 0
+        # Conditionally process cells if not skipping cell assignment
+        if not self.options.skip_cell_assignment:
+            for cluster in final_clusters:
+                cluster.cells = self._sort_cells(cluster.cells)
+                # Also sort cells in children if any
+                for child in cluster.children:
+                    child.cells = self._sort_cells(child.cells)
+
+            assert self.page.parsed_page is not None
+            self.page.parsed_page.textline_cells = self.cells
+            self.page.parsed_page.has_lines = len(self.cells) > 0
 
         return final_clusters, self.cells
 
@@ -264,36 +267,38 @@ class LayoutPostprocessor:
             if cluster.label in self.LABEL_REMAPPING:
                 cluster.label = self.LABEL_REMAPPING[cluster.label]
 
-        # Initial cell assignment
-        clusters = self._assign_cells_to_clusters(clusters)
+        # Conditionally assign cells to clusters
+        if not self.options.skip_cell_assignment:
+            # Initial cell assignment
+            clusters = self._assign_cells_to_clusters(clusters)
 
-        # Remove clusters with no cells (if keep_empty_clusters is False),
-        # but always keep clusters with label DocItemLabel.FORMULA
-        if not self.options.keep_empty_clusters:
-            clusters = [
-                cluster
-                for cluster in clusters
-                if cluster.cells or cluster.label == DocItemLabel.FORMULA
-            ]
+            # Remove clusters with no cells (if keep_empty_clusters is False),
+            # but always keep clusters with label DocItemLabel.FORMULA
+            if not self.options.keep_empty_clusters:
+                clusters = [
+                    cluster
+                    for cluster in clusters
+                    if cluster.cells or cluster.label == DocItemLabel.FORMULA
+                ]
 
-        # Handle orphaned cells
-        unassigned = self._find_unassigned_cells(clusters)
-        if unassigned and self.options.create_orphan_clusters:
-            next_id = max((c.id for c in self.all_clusters), default=0) + 1
-            orphan_clusters = []
-            for i, cell in enumerate(unassigned):
-                conf = cell.confidence
+            # Handle orphaned cells
+            unassigned = self._find_unassigned_cells(clusters)
+            if unassigned and self.options.create_orphan_clusters:
+                next_id = max((c.id for c in self.all_clusters), default=0) + 1
+                orphan_clusters = []
+                for i, cell in enumerate(unassigned):
+                    conf = cell.confidence
 
-                orphan_clusters.append(
-                    Cluster(
-                        id=next_id + i,
-                        label=DocItemLabel.TEXT,
-                        bbox=cell.to_bounding_box(),
-                        confidence=conf,
-                        cells=[cell],
+                    orphan_clusters.append(
+                        Cluster(
+                            id=next_id + i,
+                            label=DocItemLabel.TEXT,
+                            bbox=cell.to_bounding_box(),
+                            confidence=conf,
+                            cells=[cell],
+                        )
                     )
-                )
-            clusters.extend(orphan_clusters)
+                clusters.extend(orphan_clusters)
 
         # Iterative refinement
         prev_count = len(clusters) + 1
@@ -350,12 +355,15 @@ class LayoutPostprocessor:
                 b=max(c.bbox.b for c in contained),
             )
 
-            # Collect all cells from children
-            all_cells = []
-            for child in contained:
-                all_cells.extend(child.cells)
-            special.cells = self._deduplicate_cells(all_cells)
-            special.cells = self._sort_cells(special.cells)
+            # Conditionally collect cells from children
+            if not self.options.skip_cell_assignment:
+                all_cells = []
+                for child in contained:
+                    all_cells.extend(child.cells)
+                special.cells = self._deduplicate_cells(all_cells)
+                special.cells = self._sort_cells(special.cells)
+            else:
+                special.cells = []
 
         picture_clusters = [
             c for c in special_clusters if c.label == DocItemLabel.PICTURE
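With skip_cell_assignment=True the postprocessor still produces and sorts layout clusters, but never touches text cells: cell-to-cluster assignment, empty-cluster pruning, and orphan-cluster creation are all bypassed. That is what a VLM-only pipeline wants, since the VLM generates the text itself. A configuration sketch; how the options object reaches LayoutPostprocessor depends on the surrounding pipeline and is assumed here:

    from docling.datamodel.pipeline_options import LayoutOptions

    # Clusters keep their labels, bounding boxes and ordering, but
    # parsed_page.textline_cells is left untouched by the postprocessor.
    layout_options = LayoutOptions(skip_cell_assignment=True)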