Doing some experiments with granite-docling.

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar
2025-09-08 06:03:18 +02:00
parent 0e2f370f4f
commit ae9ec37cf1

View File

@@ -1,4 +1,5 @@
import logging import logging
import re
import threading import threading
import time import time
from collections.abc import Iterable from collections.abc import Iterable
@@ -6,6 +7,7 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin, DocItem
from PIL.Image import Image from PIL.Image import Image
from docling.datamodel.accelerator_options import ( from docling.datamodel.accelerator_options import (
@@ -27,6 +29,37 @@ _log = logging.getLogger(__name__)
_MLX_GLOBAL_LOCK = threading.Lock() _MLX_GLOBAL_LOCK = threading.Lock()
class DoclingStopping:
    """Stopping criterion for DocTags generation.

    Tracks the bounding box of every location tag emitted so far and
    signals a stop as soon as a newly emitted tag overlaps one of them
    (a symptom of the model looping over the same region).
    """

    def __init__(self):
        # Matches a trailing "<tag><loc_a><loc_b><loc_c><loc_d>" at the
        # end of the generated text; the optional final "<" allows the
        # next tag to have just started streaming.
        self.pattern = re.compile(
            r"<([a-z\_\-]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>(<)?$"
        )
        # Boxes seen so far in this generation; each new box is compared
        # against all of them.
        self.bboxs: list[BoundingBox] = []

    def overlaps(self, text: str) -> bool:
        """Return True if the location tag ending ``text`` overlaps a
        previously recorded bounding box; otherwise record it and
        return False (also False when ``text`` ends with no tag)."""
        match = self.pattern.search(text)
        if not match:
            return False

        # Groups 2-5 are the four <loc_*> coordinates of the tag
        # (group 1, the tag name, is not needed for the overlap test).
        loc1, loc2, loc3, loc4 = (float(match.group(i)) for i in range(2, 6))
        bbox = BoundingBox(
            l=loc1, b=loc2, r=loc3, t=loc4, coord_origin=CoordOrigin.BOTTOMLEFT
        )
        for seen in self.bboxs:
            # Tiny epsilon: any non-degenerate intersection counts.
            if bbox.intersection_over_self(seen) > 1.0e-6:
                _log.info(f"{bbox} overlaps with {seen}")
                return True
        self.bboxs.append(bbox)
        return False
class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__( def __init__(
self, self,
@@ -68,6 +101,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
self.vlm_model, self.processor = load(artifacts_path) self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path) self.config = load_config(artifacts_path)
self._find_doctags_labels()
def _find_doctags_labels(self):
"""Simple iteration over vocabulary"""
tokenizer = (
self.processor.tokenizer
if hasattr(self.processor, "tokenizer")
else self.processor
)
self.special_tokens: dict[str, int] = {}
if hasattr(tokenizer, "vocab"):
# vocab is usually a dict mapping token_text -> token_id
for token_text, token_id in tokenizer.vocab.items():
if re.match(r"^<[a-z\_\-\d]+>$", token_text):
print(f"Token ID: {token_id:6d} | Text: '{token_text}'")
self.special_tokens[token_text] = token_id
else:
print("Tokenizer doesn't have a 'vocab' attribute")
def __call__( def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page] self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]: ) -> Iterable[Page]:
@@ -199,6 +252,8 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
tokens: list[VlmPredictionToken] = [] tokens: list[VlmPredictionToken] = []
output = "" output = ""
stopping_criteria = DoclingStopping()
# Use stream_generate for proper stop string handling # Use stream_generate for proper stop string handling
for token in self.stream_generate( for token in self.stream_generate(
self.vlm_model, self.vlm_model,
@@ -209,6 +264,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
verbose=False, verbose=False,
temp=self.temperature, temp=self.temperature,
): ):
_log.info(
f"logprobs.shape: {token.logprobs.shape} with token: {token}"
)
# Collect token information # Collect token information
if len(token.logprobs.shape) == 1: if len(token.logprobs.shape) == 1:
tokens.append( tokens.append(
@@ -218,6 +277,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
logprob=token.logprobs[token.token], logprob=token.logprobs[token.token],
) )
) )
if token.text in self.special_tokens:
# Get logprobs for all special tokens
special_token_logprobs = []
for token_text, token_id in self.special_tokens.items():
logprob = token.logprobs[token_id]
special_token_logprobs.append(
(token_text, token_id, logprob)
)
# Sort by logprob (highest first) and take top 5
top_5_special = sorted(
special_token_logprobs, key=lambda x: x[2], reverse=True
)[:5]
print("Top 5 special tokens by logprob:")
for rank, (t, token_id, logprob) in enumerate(
top_5_special, 1
):
print(f" {rank}. {t}: {logprob:0.3f}")
elif ( elif (
len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1 len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
): ):
@@ -228,6 +307,11 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
logprob=token.logprobs[0, token.token], logprob=token.logprobs[0, token.token],
) )
) )
if token.text in self.special_tokens:
for t, i in self.special_tokens.items():
print(f"{t}: {token.logprobs[0, i]:0.3f}")
else: else:
_log.warning( _log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}" f"incompatible shape for logprobs: {token.logprobs.shape}"
@@ -235,6 +319,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
output += token.text output += token.text
if stopping_criteria.overlaps(output):
_log.debug("Stopping generation due to overlapping bbox")
break
# Check for any configured stop strings # Check for any configured stop strings
if self.vlm_options.stop_strings: if self.vlm_options.stop_strings:
if any( if any(
@@ -246,7 +334,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
generation_time = time.time() - start_time generation_time = time.time() - start_time
_log.debug( _log.info(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)." f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
) )