mirror of https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00

doing some experiments with granite-docling

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
@@ -1,4 +1,5 @@
 import logging
+import re
 import threading
 import time
 from collections.abc import Iterable
@@ -6,6 +7,7 @@ from pathlib import Path
 from typing import Optional, Union

 import numpy as np
+from docling_core.types.doc import BoundingBox, CoordOrigin, DocItem
 from PIL.Image import Image

 from docling.datamodel.accelerator_options import (
@@ -27,6 +29,37 @@ _log = logging.getLogger(__name__)
 _MLX_GLOBAL_LOCK = threading.Lock()


+class DoclingStopping:
+    def __init__(self):
+        self.pattern = re.compile(
+            r"<([a-z\_\-]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>(<)?$"
+        )
+
+        self.bboxs: list[BoundingBox] = []
+
+    def overlaps(self, text: str) -> bool:
+        match = self.pattern.search(text)
+        if match:
+            tag_name = match.group(1)  # group 1: the DocTags element name
+            loc1 = float(match.group(2))  # groups 2-5: the four <loc_..> values
+            loc2 = float(match.group(3))
+            loc3 = float(match.group(4))
+            loc4 = float(match.group(5))
+
+            bbox = BoundingBox(
+                l=loc1, b=loc2, r=loc3, t=loc4, coord_origin=CoordOrigin.BOTTOMLEFT
+            )
+
+            for prev in self.bboxs:
+                if bbox.intersection_over_self(prev) > 1.0e-6:
+                    _log.info(f"{bbox} overlaps with {prev}")
+                    return True
+
+            self.bboxs.append(bbox)
+
+        return False
+
+
 class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
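The DoclingStopping class added above acts as a degeneration guard: overlaps() is called on the growing output string, and the $-anchored pattern only fires when the text currently ends in a complete <tag><loc_..><loc_..><loc_..><loc_..> group. Every new box is compared against all boxes recorded so far, so a re-emitted region is treated as the model looping. A minimal usage sketch, assuming the class above is in scope and docling_core is installed:

    stop = DoclingStopping()

    # First occurrence of a region: recorded, generation may continue.
    assert stop.overlaps("<text><loc_10><loc_20><loc_110><loc_120>") is False

    # Output not currently ending in a loc-quadruple never triggers the check.
    assert stop.overlaps("<text><loc_10><loc_20><loc_110><loc_120>abc") is False

    # The same region emitted again overlaps a recorded box: signal a stop.
    assert stop.overlaps("...<text><loc_10><loc_20><loc_110><loc_120>") is True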
@@ -68,6 +101,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model, self.processor = load(artifacts_path)
         self.config = load_config(artifacts_path)
+
+        self._find_doctags_labels()
+
+    def _find_doctags_labels(self):
+        """Scan the tokenizer vocabulary for DocTags special tokens."""
+        tokenizer = (
+            self.processor.tokenizer
+            if hasattr(self.processor, "tokenizer")
+            else self.processor
+        )
+
+        self.special_tokens: dict[str, int] = {}
+        if hasattr(tokenizer, "vocab"):
+            # vocab is usually a dict mapping token_text -> token_id
+            for token_text, token_id in tokenizer.vocab.items():
+                if re.match(r"^<[a-z\_\-\d]+>$", token_text):
+                    print(f"Token ID: {token_id:6d} | Text: '{token_text}'")
+                    self.special_tokens[token_text] = token_id
+        else:
+            print("Tokenizer doesn't have a 'vocab' attribute")

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
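The new _find_doctags_labels() helper reduces the tokenizer vocabulary to tag-shaped tokens: lowercase names built from letters, digits, underscores, and hyphens, wrapped in angle brackets, which covers both structural tags and the <loc_..> location tokens. A quick illustration of the filter on a few hand-picked strings (the sample tokens are made up for the demo):

    import re

    pattern = re.compile(r"^<[a-z\_\-\d]+>$")
    for tok in ["<text>", "<otsl>", "<loc_42>", "<page_break>", "hello", "<TEXT>"]:
        print(f"{tok!r:16} -> {bool(pattern.match(tok))}")
    # Lowercase tag-like tokens match; plain words and uppercase tags do not.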
@@ -199,6 +252,8 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             tokens: list[VlmPredictionToken] = []
             output = ""

+            stopping_criteria = DoclingStopping()
+
             # Use stream_generate for proper stop string handling
             for token in self.stream_generate(
                 self.vlm_model,
@@ -209,6 +264,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 verbose=False,
                 temp=self.temperature,
             ):
+                _log.info(
+                    f"logprobs.shape: {token.logprobs.shape} with token: {token}"
+                )
+
                 # Collect token information
                 if len(token.logprobs.shape) == 1:
                     tokens.append(
@@ -218,6 +277,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                             logprob=token.logprobs[token.token],
                         )
                     )
+                    if token.text in self.special_tokens:
+                        # Get logprobs for all special tokens
+                        special_token_logprobs = []
+                        for token_text, token_id in self.special_tokens.items():
+                            logprob = token.logprobs[token_id]
+                            special_token_logprobs.append(
+                                (token_text, token_id, logprob)
+                            )
+
+                        # Sort by logprob (highest first) and take the top 5
+                        top_5_special = sorted(
+                            special_token_logprobs, key=lambda x: x[2], reverse=True
+                        )[:5]
+
+                        print("Top 5 special tokens by logprob:")
+                        for rank, (t, token_id, logprob) in enumerate(
+                            top_5_special, 1
+                        ):
+                            print(f"  {rank}. {t}: {logprob:0.3f}")
+
                 elif (
                     len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
                 ):
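The top-5 report in the 1-D branch is a plain sort over the special-token logprobs at the current step. The same ranking with toy numbers (ids and probabilities fabricated; numpy is already imported by this module):

    import numpy as np

    logprobs = np.log(np.array([0.05, 0.6, 0.05, 0.2, 0.1]))  # toy 5-token vocab
    special_tokens = {"<text>": 1, "<loc_3>": 3, "<otsl>": 4}  # made-up ids

    ranked = sorted(
        ((t, i, logprobs[i]) for t, i in special_tokens.items()),
        key=lambda x: x[2],
        reverse=True,
    )[:5]
    for rank, (t, token_id, logprob) in enumerate(ranked, 1):
        print(f"  {rank}. {t}: {logprob:0.3f}")
    #   1. <text>: -0.511
    #   2. <loc_3>: -1.609
    #   3. <otsl>: -2.303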
@@ -228,6 +307,11 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                             logprob=token.logprobs[0, token.token],
                         )
                     )
+
+                    if token.text in self.special_tokens:
+                        for t, i in self.special_tokens.items():
+                            print(f"{t}: {token.logprobs[0, i]:0.3f}")
+
                 else:
                     _log.warning(
                         f"incompatible shape for logprobs: {token.logprobs.shape}"
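The branches above dispatch on the shape of token.logprobs, which evidently can arrive either as (vocab,) or as (1, vocab) from stream_generate; both branches read the same scalar for the sampled token. A small numpy analogue of the dispatch (function name and data are illustrative only):

    import numpy as np

    def logprob_of(logprobs: np.ndarray, token_id: int) -> float:
        # Mirrors the shape handling above: accept (vocab,) or (1, vocab).
        if len(logprobs.shape) == 1:
            return float(logprobs[token_id])
        if len(logprobs.shape) == 2 and logprobs.shape[0] == 1:
            return float(logprobs[0, token_id])
        raise ValueError(f"incompatible shape for logprobs: {logprobs.shape}")

    vec = np.array([-3.0, -0.2, -1.5])
    assert logprob_of(vec, 1) == logprob_of(vec[None, :], 1) == -0.2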
@@ -235,6 +319,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):

                 output += token.text

+                if stopping_criteria.overlaps(output):
+                    _log.debug("Stopping generation due to overlapping bbox")
+                    break
+
                 # Check for any configured stop strings
                 if self.vlm_options.stop_strings:
                     if any(
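Taken together, each streamed chunk is appended to output, the overlap check runs first, and only then are the configured stop strings tested. A miniature of that control flow with a fabricated chunk stream (assumes DoclingStopping from the top of the diff):

    stopping_criteria = DoclingStopping()
    output = ""
    for text in [
        "<text><loc_1><loc_2><loc_3><loc_4>",  # first region: recorded
        "hello",                               # no loc-suffix: check stays quiet
        "<text><loc_1><loc_2><loc_3><loc_4>",  # repeated region: overlap
    ]:
        output += text
        if stopping_criteria.overlaps(output):
            print("Stopping generation due to overlapping bbox")
            break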
@@ -246,7 +334,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):

             generation_time = time.time() - start_time

-            _log.debug(
+            _log.info(
                 f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
             )