diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 52e786f2..1f610b6b 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import threading
 import time
 from collections.abc import Iterable
@@ -6,6 +7,7 @@ from pathlib import Path
 from typing import Optional, Union
 
 import numpy as np
+from docling_core.types.doc import BoundingBox, CoordOrigin
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import (
@@ -27,6 +29,37 @@ _log = logging.getLogger(__name__)
 _MLX_GLOBAL_LOCK = threading.Lock()
 
 
+class DoclingStopping:
+    """Stops generation once a newly emitted DocTags bbox overlaps one already seen."""
+
+    def __init__(self):
+        # Matches a completed `<tag><loc_a><loc_b><loc_c><loc_d>` sequence at the end.
+        self.pattern = re.compile(
+            r"<([a-z_\-]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>$"
+        )
+        self.bboxes: list[BoundingBox] = []
+
+    def overlaps(self, text: str) -> bool:
+        match = self.pattern.search(text)
+        if match:
+            tag_name = match.group(1)  # e.g. "text"
+            loc1 = float(match.group(2))
+            loc2 = float(match.group(3))
+            loc3 = float(match.group(4))
+            loc4 = float(match.group(5))
+            bbox = BoundingBox(
+                l=loc1, b=loc2, r=loc3, t=loc4, coord_origin=CoordOrigin.BOTTOMLEFT
+            )
+            for seen in self.bboxes:
+                if bbox.intersection_over_self(seen) > 1.0e-6:
+                    _log.info(f"<{tag_name}> bbox {bbox} overlaps with {seen}")
+                    return True
+
+            self.bboxes.append(bbox)
+
+        return False
+
+
 class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
@@ -68,6 +101,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
 
+            self._find_doctags_labels()
+
+    def _find_doctags_labels(self):
+        """Collect DocTags special tokens (e.g. `<text>`, `<loc_42>`) from the vocabulary."""
+        tokenizer = (
+            self.processor.tokenizer
+            if hasattr(self.processor, "tokenizer")
+            else self.processor
+        )
+
+        self.special_tokens: dict[str, int] = {}
+        if hasattr(tokenizer, "vocab"):
+            # vocab is usually a dict mapping token_text -> token_id
+            for token_text, token_id in tokenizer.vocab.items():
+                if re.match(r"^<[a-z_\-\d]+>$", token_text):
+                    _log.debug(f"Token ID: {token_id:6d} | Text: '{token_text}'")
+                    self.special_tokens[token_text] = token_id
+        else:
+            _log.warning("Tokenizer doesn't have a 'vocab' attribute")
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -199,6 +252,8 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 tokens: list[VlmPredictionToken] = []
                 output = ""
 
+                stopping_criteria = DoclingStopping()
+
                 # Use stream_generate for proper stop string handling
                 for token in self.stream_generate(
                     self.vlm_model,
@@ -209,6 +264,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     verbose=False,
                     temp=self.temperature,
                 ):
+                    _log.debug(
+                        f"logprobs.shape: {token.logprobs.shape} with token: {token}"
+                    )
+
                     # Collect token information
                     if len(token.logprobs.shape) == 1:
                         tokens.append(
@@ -218,6 +277,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                                 logprob=token.logprobs[token.token],
                             )
                         )
+                        if token.text in self.special_tokens:
+                            # Collect the logprob of every special token at this step
+                            special_token_logprobs = []
+                            for token_text, token_id in self.special_tokens.items():
+                                logprob = token.logprobs[token_id]
+                                special_token_logprobs.append(
+                                    (token_text, token_id, logprob)
+                                )
+
+                            # Sort by logprob (highest first) and take the top 5
+                            top_5_special = sorted(
+                                special_token_logprobs, key=lambda x: x[2], reverse=True
+                            )[:5]
+
+                            _log.debug("Top 5 special tokens by logprob:")
+                            for rank, (t, token_id, logprob) in enumerate(
+                                top_5_special, 1
+                            ):
+                                _log.debug(f"  {rank}. {t}: {logprob:0.3f}")
+
                     elif (
                         len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
                     ):
@@ -228,6 +307,11 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                                 logprob=token.logprobs[0, token.token],
                             )
                         )
+
+                        if token.text in self.special_tokens:
+                            for t, i in self.special_tokens.items():
+                                _log.debug(f"{t}: {token.logprobs[0, i]:0.3f}")
+
                     else:
                         _log.warning(
                             f"incompatible shape for logprobs: {token.logprobs.shape}"
@@ -235,6 +319,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
 
                     output += token.text
 
+                    if stopping_criteria.overlaps(output):
+                        _log.debug("Stopping generation due to overlapping bbox")
+                        break
+
                     # Check for any configured stop strings
                     if self.vlm_options.stop_strings:
                         if any(
@@ -246,7 +334,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
 
                 generation_time = time.time() - start_time
 
-                _log.debug(
+                _log.info(
                     f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
                 )
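For reference, a minimal sketch of how the new stopping criterion behaves, assuming docling-core is installed and the patched module is importable; the DocTags strings are made up for illustration. The 1.0e-6 intersection-over-self threshold means essentially any overlap triggers a stop:

```python
from docling.models.vlm_models_inline.mlx_model import DoclingStopping

crit = DoclingStopping()

# First completed `<tag><loc_...>` sequence: no prior boxes, so it is recorded.
output = "<text><loc_10><loc_10><loc_200><loc_40>"
assert not crit.overlaps(output)

# The model re-emits a nearly identical region: the new bbox overlaps the
# recorded one, so generation should stop here.
output += "<text><loc_10><loc_12><loc_200><loc_42>"
if crit.overlaps(output):
    print("stop: model is re-emitting an already-covered region")
```

In the streaming loop, `overlaps()` is evaluated on the growing `output` after every token, so the `$`-anchored pattern fires exactly once per region, on the token that completes the fourth `<loc_...>` tag.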