doing some experiments with granite-docling
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
@@ -1,4 +1,5 @@
 import logging
+import re
 import threading
 import time
 from collections.abc import Iterable
@@ -6,6 +7,7 @@ from pathlib import Path
 from typing import Optional, Union
 
 import numpy as np
+from docling_core.types.doc import BoundingBox, CoordOrigin, DocItem
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import (
@@ -27,6 +29,37 @@ _log = logging.getLogger(__name__)
 _MLX_GLOBAL_LOCK = threading.Lock()
 
 
+class DoclingStopping:
+    def __init__(self):
+        self.pattern = re.compile(  # a DocTags element + bbox at end of text
+            r"<([a-z\_\-]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>(<)?$"
+        )
+
+        self.bboxs: list[BoundingBox] = []
+
+    def overlaps(self, text: str) -> bool:
+        match = re.search(self.pattern, text)
+        if match:
+            tag_name = match.group(1)  # DocTags element name, e.g. "text"
+            loc1 = float(match.group(2))  # left
+            loc2 = float(match.group(3))  # bottom
+            loc3 = float(match.group(4))  # right
+            loc4 = float(match.group(5))  # top
+
+            bbox = BoundingBox(
+                l=loc1, b=loc2, r=loc3, t=loc4, coord_origin=CoordOrigin.BOTTOMLEFT
+            )
+
+            for prev in self.bboxs:
+                if bbox.intersection_over_self(prev) > 1.0e-6:
+                    _log.info(f"{bbox} overlaps with {prev}")
+                    return True
+
+            self.bboxs.append(bbox)
+
+        return False
+
+
 class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
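Aside: a minimal sketch of the intended behavior of DoclingStopping, assuming the class from the hunk above; the DocTags strings are hypothetical model output, not taken from the repository:

    stopper = DoclingStopping()

    # First occurrence of a region: its bbox is recorded, no overlap yet.
    assert not stopper.overlaps("<text><loc_10><loc_20><loc_110><loc_220>")

    # A later element landing on (nearly) the same region trips the check.
    assert stopper.overlaps(
        "<text><loc_10><loc_20><loc_110><loc_220>"
        "<text><loc_12><loc_22><loc_108><loc_218>"
    )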
@@ -68,6 +101,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model, self.processor = load(artifacts_path)
         self.config = load_config(artifacts_path)
+
+        self._find_doctags_labels()
+
+    def _find_doctags_labels(self):
+        """Simple iteration over vocabulary"""
+        tokenizer = (
+            self.processor.tokenizer
+            if hasattr(self.processor, "tokenizer")
+            else self.processor
+        )
+
+        self.special_tokens: dict[str, int] = {}
+        if hasattr(tokenizer, "vocab"):
+            # vocab is usually a dict mapping token_text -> token_id
+            for token_text, token_id in tokenizer.vocab.items():
+                if re.match(r"^<[a-z\_\-\d]+>$", token_text):
+                    print(f"Token ID: {token_id:6d} | Text: '{token_text}'")
+                    self.special_tokens[token_text] = token_id
+        else:
+            print("Tokenizer doesn't have a 'vocab' attribute")
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
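Aside: the vocabulary scan from _find_doctags_labels in isolation, run against a hypothetical toy vocab (real tokenizers expose a much larger token_text -> token_id mapping, and the ids below are made up):

    import re

    vocab = {"<text>": 50257, "<loc_42>": 50258, "Hello": 15496, "<otsl>": 50259}
    special = {t: i for t, i in vocab.items() if re.match(r"^<[a-z\_\-\d]+>$", t)}
    print(special)  # {'<text>': 50257, '<loc_42>': 50258, '<otsl>': 50259}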
@@ -199,6 +252,8 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 tokens: list[VlmPredictionToken] = []
                 output = ""
 
+                stopping_criteria = DoclingStopping()
+
                 # Use stream_generate for proper stop string handling
                 for token in self.stream_generate(
                     self.vlm_model,
@@ -209,6 +264,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     verbose=False,
                     temp=self.temperature,
                 ):
+                    _log.info(
+                        f"logprobs.shape: {token.logprobs.shape} with token: {token}"
+                    )
+
                     # Collect token information
                     if len(token.logprobs.shape) == 1:
                         tokens.append(
@@ -218,6 +277,26 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                                 logprob=token.logprobs[token.token],
                             )
                         )
+                        if token.text in self.special_tokens:
+                            # Get logprobs for all special tokens
+                            special_token_logprobs = []
+                            for token_text, token_id in self.special_tokens.items():
+                                logprob = token.logprobs[token_id]
+                                special_token_logprobs.append(
+                                    (token_text, token_id, logprob)
+                                )
+
+                            # Sort by logprob (highest first) and take top 5
+                            top_5_special = sorted(
+                                special_token_logprobs, key=lambda x: x[2], reverse=True
+                            )[:5]
+
+                            print("Top 5 special tokens by logprob:")
+                            for rank, (t, token_id, logprob) in enumerate(
+                                top_5_special, 1
+                            ):
+                                print(f"  {rank}. {t}: {logprob:0.3f}")
+
                     elif (
                         len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
                     ):
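Aside: the top-5 selection above sorts the full list; heapq.nlargest is an equivalent formulation that avoids the full sort. The logprob values are made up for illustration:

    import heapq

    special_token_logprobs = [
        ("<text>", 1, -0.2),
        ("<loc_5>", 2, -3.1),
        ("<otsl>", 3, -0.9),
    ]
    top_5 = heapq.nlargest(5, special_token_logprobs, key=lambda x: x[2])
    print(top_5[0])  # ('<text>', 1, -0.2)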
@@ -228,6 +307,11 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                                 logprob=token.logprobs[0, token.token],
                             )
                         )
+
+                        if token.text in self.special_tokens:
+                            for t, i in self.special_tokens.items():
+                                print(f"{t}: {token.logprobs[0, i]:0.3f}")
+
                     else:
                         _log.warning(
                             f"incompatible shape for logprobs: {token.logprobs.shape}"
@@ -235,6 +319,10 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
 
                     output += token.text
 
+                    if stopping_criteria.overlaps(output):
+                        _log.debug("Stopping generation due to overlapping bbox")
+                        break
+
                     # Check for any configured stop strings
                     if self.vlm_options.stop_strings:
                         if any(
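Aside: a sketch of the early exit this check enables on a degenerate, repeating stream, assuming the DoclingStopping class from the first hunk; the repeated piece is a hypothetical stand-in for stream_generate output:

    stopper = DoclingStopping()
    output = ""
    for piece in ["<text><loc_10><loc_20><loc_110><loc_220>"] * 3:
        output += piece
        if stopper.overlaps(output):
            print("stopped early")  # fires on the second, identical region
            break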
@@ -246,7 +334,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
 
         generation_time = time.time() - start_time
 
-        _log.debug(
+        _log.info(
             f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
         )
 