mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
feat: new code formula model (#2042)
* new code formula model Signed-off-by: mao <mao@lenny.zuvela.ibm.com> * new model on hf Signed-off-by: mao <mao@tabby.zuvela.ibm.com> * pre-commits Signed-off-by: mao <mao@login-c.zuvela.ibm.com> * remove MPS Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: mao <mao@lenny.zuvela.ibm.com> Signed-off-by: mao <mao@tabby.zuvela.ibm.com> Signed-off-by: mao <mao@login-c.zuvela.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: mao <mao@lenny.zuvela.ibm.com> Co-authored-by: mao <mao@tabby.zuvela.ibm.com> Co-authored-by: mao <mao@login-c.zuvela.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from collections import Counter
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
@@ -13,10 +12,11 @@ from docling_core.types.doc import (
|
||||
TextItem,
|
||||
)
|
||||
from docling_core.types.doc.labels import CodeLanguageLabel
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
||||
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
||||
from docling.models.utils.hf_model_download import download_hf_model
|
||||
@@ -65,9 +65,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
Processes the given batch of elements and enriches them with predictions.
|
||||
"""
|
||||
|
||||
_model_repo_folder = "ds4sd--CodeFormula"
|
||||
_model_repo_folder = "ds4sd--CodeFormulaV2"
|
||||
elements_batch_size = 5
|
||||
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
||||
images_scale = 1.67 # = 120 dpi, aligned with training data resolution
|
||||
expansion_factor = 0.18
|
||||
|
||||
def __init__(
|
||||
@@ -95,10 +95,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
self.options = options
|
||||
|
||||
if self.enabled:
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
from docling_ibm_models.code_formula_model.code_formula_predictor import (
|
||||
CodeFormulaPredictor,
|
||||
self.device = decide_device(
|
||||
accelerator_options.device,
|
||||
supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
|
||||
)
|
||||
|
||||
if artifacts_path is None:
|
||||
@@ -106,11 +105,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
else:
|
||||
artifacts_path = artifacts_path / self._model_repo_folder
|
||||
|
||||
self.code_formula_model = CodeFormulaPredictor(
|
||||
artifacts_path=str(artifacts_path),
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
artifacts_path,
|
||||
)
|
||||
self._model_max_length = self._processor.tokenizer.model_max_length
|
||||
self._model = AutoModelForImageTextToText.from_pretrained(
|
||||
artifacts_path, device_map=self.device
|
||||
)
|
||||
self._model.eval()
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
@@ -119,8 +121,8 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
return download_hf_model(
|
||||
repo_id="ds4sd/CodeFormula",
|
||||
revision="v1.0.2",
|
||||
repo_id="ds4sd/CodeFormulaV2",
|
||||
revision="main",
|
||||
local_dir=local_dir,
|
||||
force=force,
|
||||
progress=progress,
|
||||
@@ -172,7 +174,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
- The second element is the extracted language if a match is found;
|
||||
otherwise, `None`.
|
||||
"""
|
||||
pattern = r"^<_([^_>]+)_>\s(.*)"
|
||||
pattern = r"^<_([^_>]+)_>\s*(.*)"
|
||||
match = re.match(pattern, input_string, flags=re.DOTALL)
|
||||
if match:
|
||||
language = str(match.group(1)) # the captured programming language
|
||||
@@ -203,81 +205,74 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
except ValueError:
|
||||
return CodeLanguageLabel.UNKNOWN
|
||||
|
||||
def _get_most_frequent_edge_color(self, pil_img: Image.Image):
|
||||
def _get_prompt(self, label: str) -> str:
|
||||
"""
|
||||
Compute the most frequent color along the outer edges of a PIL image.
|
||||
Constructs the prompt for the model based on the input label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pil_img : Image.Image
|
||||
A PIL Image in any mode (L, RGB, RGBA, etc.).
|
||||
label : str
|
||||
The type of input, either 'code' or 'formula'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(int) or (tuple): The most common edge color as a scalar (for grayscale) or
|
||||
tuple (for RGB/RGBA).
|
||||
str
|
||||
The constructed prompt including necessary tokens and query.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
If the label is not 'code' or 'formula'.
|
||||
"""
|
||||
# Convert to NumPy array for easy pixel access
|
||||
img_np = np.array(pil_img)
|
||||
|
||||
if img_np.ndim == 2:
|
||||
# Grayscale-like image: shape (H, W)
|
||||
# Extract edges: top row, bottom row, left col, right col
|
||||
top = img_np[0, :] # shape (W,)
|
||||
bottom = img_np[-1, :] # shape (W,)
|
||||
left = img_np[:, 0] # shape (H,)
|
||||
right = img_np[:, -1] # shape (H,)
|
||||
|
||||
# Concatenate all edges
|
||||
edges = np.concatenate([top, bottom, left, right])
|
||||
|
||||
# Count frequencies
|
||||
freq = Counter(edges.tolist())
|
||||
most_common_value, _ = freq.most_common(1)[0]
|
||||
return int(most_common_value) # single channel color
|
||||
|
||||
if label == "code":
|
||||
query = "<code>"
|
||||
elif label == "formula":
|
||||
query = "<formula>"
|
||||
else:
|
||||
# Color image: shape (H, W, C)
|
||||
top = img_np[0, :, :] # shape (W, C)
|
||||
bottom = img_np[-1, :, :] # shape (W, C)
|
||||
left = img_np[:, 0, :] # shape (H, C)
|
||||
right = img_np[:, -1, :] # shape (H, C)
|
||||
raise NotImplementedError("Label must be either code or formula")
|
||||
|
||||
# Concatenate edges along first axis
|
||||
edges = np.concatenate([top, bottom, left, right], axis=0)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": query}],
|
||||
},
|
||||
]
|
||||
|
||||
# Convert each color to a tuple for counting
|
||||
edges_as_tuples = [tuple(pixel) for pixel in edges]
|
||||
freq = Counter(edges_as_tuples)
|
||||
most_common_value, _ = freq.most_common(1)[0]
|
||||
return most_common_value # e.g. (R, G, B) or (R, G, B, A)
|
||||
prompt = self._processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
def _pad_with_most_frequent_edge_color(
|
||||
self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
|
||||
):
|
||||
return prompt
|
||||
|
||||
def _post_process(self, texts: list[str]) -> list[str]:
|
||||
"""
|
||||
Pads an image (PIL or NumPy array) using the most frequent edge color.
|
||||
Processes a list of text strings by truncating at '<end_of_utterance>' and
|
||||
removing a predefined set of unwanted substrings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img : Union[Image.Image, np.ndarray]
|
||||
The original image.
|
||||
padding : tuple
|
||||
Padding (left, top, right, bottom) in pixels.
|
||||
texts : list[str]
|
||||
A list of strings to be post-processed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Image.Image: A new PIL image with the specified padding.
|
||||
list[str]
|
||||
A list of cleaned strings with specified substrings removed and truncated at
|
||||
'<end_of_utterance>' if present.
|
||||
"""
|
||||
if isinstance(img, np.ndarray):
|
||||
pil_img = Image.fromarray(img)
|
||||
else:
|
||||
pil_img = img
|
||||
to_remove = ["</code>", "</formula>", "<loc_0><loc_0><loc_500><loc_500>"]
|
||||
|
||||
most_freq_color = self._get_most_frequent_edge_color(pil_img)
|
||||
def clean_text(text: str) -> str:
|
||||
idx = text.find("<end_of_utterance>")
|
||||
if idx != -1:
|
||||
text = text[:idx]
|
||||
|
||||
padded_img = ImageOps.expand(pil_img, border=padding, fill=most_freq_color)
|
||||
return padded_img
|
||||
for token in to_remove:
|
||||
if token in text:
|
||||
text = text.replace(token, "")
|
||||
return text.lstrip()
|
||||
|
||||
return [clean_text(t) for t in texts]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -308,14 +303,30 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
images: List[Union[Image.Image, np.ndarray]] = []
|
||||
elements: List[TextItem] = []
|
||||
for el in element_batch:
|
||||
assert isinstance(el.item, TextItem)
|
||||
elements.append(el.item)
|
||||
labels.append(el.item.label)
|
||||
images.append(
|
||||
self._pad_with_most_frequent_edge_color(el.image, (20, 10, 20, 10))
|
||||
)
|
||||
elements.append(el.item) # type: ignore[arg-type]
|
||||
labels.append(el.item.label) # type: ignore[attr-defined]
|
||||
images.append(el.image)
|
||||
|
||||
outputs = self.code_formula_model.predict(images, labels)
|
||||
prompts = [self._get_prompt(label) for label in labels]
|
||||
inputs = self._processor(
|
||||
text=prompts,
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(self.device)
|
||||
|
||||
gen_kwargs = dict(
|
||||
max_new_tokens=self._model_max_length - inputs.input_ids.shape[1],
|
||||
use_cache=True,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
generated_ids = self._model.generate(**inputs, **gen_kwargs)
|
||||
|
||||
outputs = self._processor.batch_decode(
|
||||
generated_ids[:, inputs.input_ids.shape[1] :], skip_special_tokens=False
|
||||
)
|
||||
outputs = self._post_process(outputs)
|
||||
|
||||
for item, output in zip(elements, outputs):
|
||||
if isinstance(item, CodeItem):
|
||||
|
||||
Reference in New Issue
Block a user