feat: new code formula model (#2042)

* new code formula model

Signed-off-by: mao <mao@lenny.zuvela.ibm.com>

* new model on hf

Signed-off-by: mao <mao@tabby.zuvela.ibm.com>

* pre-commits

Signed-off-by: mao <mao@login-c.zuvela.ibm.com>

* remove MPS

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: mao <mao@lenny.zuvela.ibm.com>
Signed-off-by: mao <mao@tabby.zuvela.ibm.com>
Signed-off-by: mao <mao@login-c.zuvela.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: mao <mao@lenny.zuvela.ibm.com>
Co-authored-by: mao <mao@tabby.zuvela.ibm.com>
Co-authored-by: mao <mao@login-c.zuvela.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Matteo
2025-08-18 16:01:46 +02:00
committed by GitHub
parent c3a7d1d999
commit d2494da8b8

View File

@@ -1,5 +1,4 @@
import re
from collections import Counter
from collections.abc import Iterable
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union
@@ -13,10 +12,11 @@ from docling_core.types.doc import (
TextItem,
)
from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image, ImageOps
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.models.utils.hf_model_download import download_hf_model
@@ -65,9 +65,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
Processes the given batch of elements and enriches them with predictions.
"""
_model_repo_folder = "ds4sd--CodeFormula"
_model_repo_folder = "ds4sd--CodeFormulaV2"
elements_batch_size = 5
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
images_scale = 1.67 # = 120 dpi, aligned with training data resolution
expansion_factor = 0.18
def __init__(
@@ -95,10 +95,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
self.options = options
if self.enabled:
device = decide_device(accelerator_options.device)
from docling_ibm_models.code_formula_model.code_formula_predictor import (
CodeFormulaPredictor,
self.device = decide_device(
accelerator_options.device,
supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
)
if artifacts_path is None:
@@ -106,11 +105,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
else:
artifacts_path = artifacts_path / self._model_repo_folder
self.code_formula_model = CodeFormulaPredictor(
artifacts_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
self._processor = AutoProcessor.from_pretrained(
artifacts_path,
)
self._model_max_length = self._processor.tokenizer.model_max_length
self._model = AutoModelForImageTextToText.from_pretrained(
artifacts_path, device_map=self.device
)
self._model.eval()
@staticmethod
def download_models(
@@ -119,8 +121,8 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
progress: bool = False,
) -> Path:
return download_hf_model(
repo_id="ds4sd/CodeFormula",
revision="v1.0.2",
repo_id="ds4sd/CodeFormulaV2",
revision="main",
local_dir=local_dir,
force=force,
progress=progress,
@@ -172,7 +174,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
- The second element is the extracted language if a match is found;
otherwise, `None`.
"""
pattern = r"^<_([^_>]+)_>\s(.*)"
pattern = r"^<_([^_>]+)_>\s*(.*)"
match = re.match(pattern, input_string, flags=re.DOTALL)
if match:
language = str(match.group(1)) # the captured programming language
@@ -203,81 +205,74 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
except ValueError:
return CodeLanguageLabel.UNKNOWN
def _get_most_frequent_edge_color(self, pil_img: Image.Image):
def _get_prompt(self, label: str) -> str:
"""
Compute the most frequent color along the outer edges of a PIL image.
Constructs the prompt for the model based on the input label.
Parameters
----------
pil_img : Image.Image
A PIL Image in any mode (L, RGB, RGBA, etc.).
label : str
The type of input, either 'code' or 'formula'.
Returns
-------
(int) or (tuple): The most common edge color as a scalar (for grayscale) or
tuple (for RGB/RGBA).
str
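The chat-template-formatted prompt: a single user message holding an image
placeholder followed by the '<code>' or '<formula>' query, rendered with the
generation prompt appended.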
Raises
------
NotImplementedError
If the label is not 'code' or 'formula'.
"""
# Convert to NumPy array for easy pixel access
img_np = np.array(pil_img)
if img_np.ndim == 2:
# Grayscale-like image: shape (H, W)
# Extract edges: top row, bottom row, left col, right col
top = img_np[0, :] # shape (W,)
bottom = img_np[-1, :] # shape (W,)
left = img_np[:, 0] # shape (H,)
right = img_np[:, -1] # shape (H,)
# Concatenate all edges
edges = np.concatenate([top, bottom, left, right])
# Count frequencies
freq = Counter(edges.tolist())
most_common_value, _ = freq.most_common(1)[0]
return int(most_common_value) # single channel color
if label == "code":
query = "<code>"
elif label == "formula":
query = "<formula>"
else:
# Color image: shape (H, W, C)
top = img_np[0, :, :] # shape (W, C)
bottom = img_np[-1, :, :] # shape (W, C)
left = img_np[:, 0, :] # shape (H, C)
right = img_np[:, -1, :] # shape (H, C)
raise NotImplementedError("Label must be either code or formula")
# Concatenate edges along first axis
edges = np.concatenate([top, bottom, left, right], axis=0)
messages = [
{
"role": "user",
"content": [{"type": "image"}, {"type": "text", "text": query}],
},
]
# Convert each color to a tuple for counting
edges_as_tuples = [tuple(pixel) for pixel in edges]
freq = Counter(edges_as_tuples)
most_common_value, _ = freq.most_common(1)[0]
return most_common_value # e.g. (R, G, B) or (R, G, B, A)
prompt = self._processor.apply_chat_template(
messages, add_generation_prompt=True
)
def _pad_with_most_frequent_edge_color(
self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
):
return prompt
def _post_process(self, texts: list[str]) -> list[str]:
"""
Pads an image (PIL or NumPy array) using the most frequent edge color.
Processes a list of text strings by truncating at '<end_of_utterance>' and
removing a predefined set of unwanted substrings.
Parameters
----------
img : Union[Image.Image, np.ndarray]
The original image.
padding : tuple
Padding (left, top, right, bottom) in pixels.
texts : list[str]
A list of strings to be post-processed.
Returns
-------
Image.Image: A new PIL image with the specified padding.
list[str]
A list of cleaned strings: each one is truncated at '<end_of_utterance>' if
present, then has the unwanted substrings removed, and finally has its leading
whitespace stripped.
"""
if isinstance(img, np.ndarray):
pil_img = Image.fromarray(img)
else:
pil_img = img
to_remove = ["</code>", "</formula>", "<loc_0><loc_0><loc_500><loc_500>"]
most_freq_color = self._get_most_frequent_edge_color(pil_img)
def clean_text(text: str) -> str:
idx = text.find("<end_of_utterance>")
if idx != -1:
text = text[:idx]
padded_img = ImageOps.expand(pil_img, border=padding, fill=most_freq_color)
return padded_img
for token in to_remove:
if token in text:
text = text.replace(token, "")
return text.lstrip()
return [clean_text(t) for t in texts]
def __call__(
self,
@@ -308,14 +303,30 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
images: List[Union[Image.Image, np.ndarray]] = []
elements: List[TextItem] = []
for el in element_batch:
assert isinstance(el.item, TextItem)
elements.append(el.item)
labels.append(el.item.label)
images.append(
self._pad_with_most_frequent_edge_color(el.image, (20, 10, 20, 10))
)
elements.append(el.item) # type: ignore[arg-type]
labels.append(el.item.label) # type: ignore[attr-defined]
images.append(el.image)
outputs = self.code_formula_model.predict(images, labels)
prompts = [self._get_prompt(label) for label in labels]
inputs = self._processor(
text=prompts,
images=images,
return_tensors="pt",
)
inputs = inputs.to(self.device)
gen_kwargs = dict(
max_new_tokens=self._model_max_length - inputs.input_ids.shape[1],
use_cache=True,
do_sample=False,
)
generated_ids = self._model.generate(**inputs, **gen_kwargs)
outputs = self._processor.batch_decode(
generated_ids[:, inputs.input_ids.shape[1] :], skip_special_tokens=False
)
outputs = self._post_process(outputs)
for item, output in zip(elements, outputs):
if isinstance(item, CodeItem):