New assembly code for latest model revision, updated prompt and parsing of doctags, updated logging

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
Maksym Lysak 2025-02-11 13:34:14 +01:00
parent 7c4ab5c716
commit 479ee239aa
2 changed files with 138 additions and 257 deletions

View File

@ -30,7 +30,6 @@ class SmolDoclingModel(BasePageModel):
def __init__(
self,
# artifacts_path: Path,
accelerator_options: AcceleratorOptions,
vlm_options: SmolDoclingOptions,
):
@ -76,6 +75,16 @@ class SmolDoclingModel(BasePageModel):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
print(
"Processed image resolution: {} x {}".format(
im_width, im_height
)
)
# populate page_tags with predicted doc tags
page_tags = ""
@ -103,8 +112,8 @@ class SmolDoclingModel(BasePageModel):
text=prompt, images=[hi_res_image], return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
prompt = prompt.replace("<end_of_utterance>", "")
print("In the model, starting to generate...")
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
@ -129,11 +138,9 @@ class SmolDoclingModel(BasePageModel):
tokens_per_second = num_tokens / generation_time
print("")
print(f"Page Inference Time: {inference_time:.2f} seconds")
print(f"Total tokens on page: {num_tokens:.2f}")
print(f"Tokens/sec: {tokens_per_second:.2f}")
print("")
print("Page predictions:")
print(page_tags)
page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)
yield page

View File

@ -1,7 +1,8 @@
import itertools
import logging
import re
from io import BytesIO
# from io import BytesIO
from pathlib import Path
from typing import Optional
@ -45,6 +46,49 @@ class VlmPipeline(PaginatedPipeline):
# force_backend_text = True - get text from backend using bounding boxes predicted by SmolDocling
self.force_backend_text = pipeline_options.force_backend_text
###############################################
# Tag definitions and color mappings
###############################################
# Maps the recognized tag to a Docling label.
# Code items will be given DocItemLabel.CODE
self.tag_to_doclabel = {
"title": DocItemLabel.TITLE,
"document_index": DocItemLabel.DOCUMENT_INDEX,
"otsl": DocItemLabel.TABLE,
"section_header_level_1": DocItemLabel.SECTION_HEADER,
"checkbox_selected": DocItemLabel.CHECKBOX_SELECTED,
"checkbox_unselected": DocItemLabel.CHECKBOX_UNSELECTED,
"text": DocItemLabel.TEXT,
"page_header": DocItemLabel.PAGE_HEADER,
"page_footer": DocItemLabel.PAGE_FOOTER,
"formula": DocItemLabel.FORMULA,
"caption": DocItemLabel.CAPTION,
"picture": DocItemLabel.PICTURE,
"list_item": DocItemLabel.LIST_ITEM,
"footnote": DocItemLabel.FOOTNOTE,
"code": DocItemLabel.CODE,
}
# Maps each tag to an associated bounding box color.
self.tag_to_color = {
"title": "blue",
"document_index": "darkblue",
"otsl": "green",
"section_header_level_1": "purple",
"checkbox_selected": "black",
"checkbox_unselected": "gray",
"text": "red",
"page_header": "orange",
"page_footer": "cyan",
"formula": "pink",
"caption": "magenta",
"picture": "yellow",
"list_item": "brown",
"footnote": "darkred",
"code": "lightblue",
}
"""
if pipeline_options.artifacts_path is None:
self.artifacts_path = self.download_models_hf()
@ -136,6 +180,18 @@ class VlmPipeline(PaginatedPipeline):
def _turn_tags_into_doc(self, pages: list[Page]) -> DoclingDocument:
def extract_bounding_box(text_chunk: str) -> Optional[BoundingBox]:
    """Build a bounding box from the ``<loc_N>`` tokens found in a chunk.

    Every ``<loc_N>`` value is divided by 500 before it is stored. The first
    four values are read as left, top, right and bottom, in that order.

    Args:
        text_chunk: Raw tagged text of a single element.

    Returns:
        A ``BoundingBox`` built from the first four location tokens, or
        ``None`` when the chunk carries fewer than four of them.
    """
    coords = re.findall(r"<loc_(\d+)>", text_chunk)
    if len(coords) < 4:
        return None
    # A chunk can hold more than four tokens when it wraps another located
    # element; requiring exactly four would silently drop the box in that
    # case. NOTE(review): assumes the outer element's own tokens come first.
    left, top, right, bottom = (float(value) / 500 for value in coords[:4])
    return BoundingBox(l=left, t=top, r=right, b=bottom)
def extract_inner_text(text_chunk: str) -> str:
    """Return the chunk's text content with its markup tokens removed.

    Only tag-shaped tokens are stripped: ``<name>`` or ``</name>`` where
    ``name`` consists of ASCII letters, digits and underscores (for example
    ``<text>``, ``</text>`` or ``<loc_42>``). Literal ``<`` and ``>``
    characters in the content, such as comparisons in formulas or code, are
    kept; a lazy ``<.*?>`` match would delete everything between them.

    Args:
        text_chunk: Raw tagged text of a single element.

    Returns:
        The remaining text with leading and trailing whitespace stripped.
    """
    return re.sub(r"</?[A-Za-z_][A-Za-z0-9_]*>", "", text_chunk).strip()
def extract_text_from_backend(page: Page, bbox: BoundingBox | None) -> str:
# Convert the normalized bounding box (loc values divided by 500) into page coordinates for cropping
text = ""
@ -149,18 +205,7 @@ class VlmPipeline(PaginatedPipeline):
text = page._backend.get_text_in_rect(bbox)
return text
def extract_text(tag_content: str) -> str:
return re.sub(r"<.*?>", "", tag_content).strip()
def extract_bounding_box(tag_content: str) -> Optional[BoundingBox]:
locs = re.findall(r"<loc_(\d+)>", tag_content)
if len(locs) == 4:
l, t, r, b = map(float, locs)
l, t, r, b = [coord / 500.0 for coord in (l, t, r, b)]
return BoundingBox(l=l, t=t, r=r, b=b)
return None
def parse_texts(texts, tokens):
def otsl_parse_texts(texts, tokens):
split_word = TableToken.OTSL_NL.value
split_row_tokens = [
list(y)
@ -267,7 +312,7 @@ class VlmPipeline(PaginatedPipeline):
c_idx = 0
return table_cells, split_row_tokens
def extract_tokens_and_text(s: str):
def otsl_extract_tokens_and_text(s: str):
# Pattern to match anything enclosed by < > (including the angle brackets themselves)
pattern = r"(<[^>]+>)"
# Find all tokens (e.g. "<otsl>", "<loc_140>", etc.)
@ -291,8 +336,8 @@ class VlmPipeline(PaginatedPipeline):
return tokens, text_parts
def parse_table_content(otsl_content: str) -> TableData:
tokens, mixed_texts = extract_tokens_and_text(otsl_content)
table_cells, split_row_tokens = parse_texts(mixed_texts, tokens)
tokens, mixed_texts = otsl_extract_tokens_and_text(otsl_content)
table_cells, split_row_tokens = otsl_parse_texts(mixed_texts, tokens)
return TableData(
num_rows=len(split_row_tokens),
@ -302,15 +347,18 @@ class VlmPipeline(PaginatedPipeline):
table_cells=table_cells,
)
doc = DoclingDocument(name="Example Document")
current_group = None
doc = DoclingDocument(name="Document")
for pg_idx, page in enumerate(pages):
xml_content = ""
predicted_text = ""
if page.predictions.doctags:
xml_content = page.predictions.doctags.tag_string
pil_image = page.image
print("Doctags predicted for a page {}:".format(pg_idx))
print(page.predictions.doctags)
print("")
predicted_text = page.predictions.doctags.tag_string
image = page.image
page_no = pg_idx + 1
bounding_boxes = []
if page.size:
pg_width = page.size.width
@ -318,258 +366,84 @@ class VlmPipeline(PaginatedPipeline):
size = Size(width=pg_width, height=pg_height)
parent_page = doc.add_page(page_no=page_no, size=size)
lines = xml_content.split("\n")
bounding_boxes = []
"""
1. Finds all <tag>...</tag> blocks in the entire string (multi-line friendly) in the order they appear.
2. For each chunk, extracts bounding box (if any) and inner text.
3. Adds the item to a DoclingDocument structure with the right label.
4. Tracks bounding boxes + color in a separate list for later visualization.
"""
for line in lines:
line = line.strip()
line = line.replace("<doctag>", "")
if line.startswith("<paragraph>"):
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
# Regex for all recognized tags
tag_pattern = (
r"<(?P<tag>title|document_index|otsl|section_header_level_1|checkbox_selected|"
r"checkbox_unselected|text|page_header|page_footer|formula|caption|picture|"
r"list_item|footnote|code)>.*?</(?P=tag)>"
)
pattern = re.compile(tag_pattern, re.DOTALL)
if prov_item:
bounding_boxes.append((prov_item, "red"))
doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=content,
parent=current_group,
prov=(
ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
if prov_item
else None
),
)
elif line.startswith("<title>"):
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
# Go through each match in order
for match in pattern.finditer(predicted_text):
full_chunk = match.group(0)
tag_name = match.group("tag")
if prov_item:
bounding_boxes.append((prov_item, "blue"))
current_group = doc.add_group(
label=GroupLabel.SECTION, name=content
)
doc.add_text(
label=DocItemLabel.TITLE,
text=content,
parent=current_group,
prov=(
ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
if prov_item
else None
),
)
bbox = extract_bounding_box(full_chunk)
doc_label = self.tag_to_doclabel.get(tag_name, DocItemLabel.PARAGRAPH)
color = self.tag_to_color.get(tag_name, "white")
elif line.startswith("<section_header_level_1>"):
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
# Store bounding box + color
if bbox:
bounding_boxes.append((bbox, color))
if prov_item:
bounding_boxes.append((prov_item, "green"))
current_group = doc.add_group(
label=GroupLabel.SECTION, name=content
)
doc.add_text(
label=DocItemLabel.SECTION_HEADER,
text=content,
parent=current_group,
prov=(
ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
if prov_item
else None
),
)
if tag_name == "otsl":
table_data = parse_table_content(full_chunk)
doc.add_table(data=table_data)
elif line.startswith("<otsl>"):
prov_item = extract_bounding_box(line)
if prov_item:
bounding_boxes.append((prov_item, "aquamarine"))
table_data = parse_table_content(line)
doc.add_table(data=table_data, parent=current_group)
elif line.startswith("<footnote>"):
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "orange"))
doc.add_text(
label=DocItemLabel.FOOTNOTE,
text=content,
parent=current_group,
prov=(
ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
if prov_item
else None
),
)
elif line.startswith("<page_header>"):
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "purple"))
doc.add_text(
label=DocItemLabel.PAGE_HEADER,
text=content,
parent=current_group,
prov=(
ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
if prov_item
else None
),
)
elif line.startswith("<page_footer>"):
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "cyan"))
doc.add_text(
label=DocItemLabel.PAGE_FOOTER,
text=content,
parent=current_group,
prov=(
ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
if prov_item
else None
),
)
elif line.startswith("<picture>"):
bbox = extract_bounding_box(line)
if bbox:
bounding_boxes.append((bbox, "yellow"))
if pil_image:
# Convert bounding box normalized to 0-100 into pixel coordinates for cropping
width, height = pil_image.size
elif tag_name == "picture":
if image:
if bbox:
width, height = image.size
crop_box = (
int(bbox.l * width),
int(bbox.t * height),
int(bbox.r * width),
int(bbox.b * height),
)
cropped_image = pil_image.crop(crop_box)
cropped_image = image.crop(crop_box)
doc.add_picture(
parent=current_group,
image=ImageRef.from_pil(image=cropped_image, dpi=300),
parent=None,
image=ImageRef.from_pil(image=cropped_image, dpi=72),
prov=(
ProvenanceItem(
bbox=bbox, charspan=(0, 0), page_no=page_no
)
),
)
else:
if bbox:
doc.add_picture(
parent=None,
prov=ProvenanceItem(
bbox=bbox, charspan=(0, 0), page_no=page_no
),
)
else:
doc.add_picture(
parent=current_group,
prov=ProvenanceItem(
bbox=bbox, charspan=(0, 0), page_no=page_no
),
)
elif line.startswith("<list_item>"):
prov_item_inst = None
prov_item = extract_bounding_box(line)
else:
# For everything else, treat as text
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
content = extract_text_from_backend(page, bbox)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "brown"))
prov_item_inst = ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
text_content = extract_inner_text(full_chunk)
# If it's code, wrap it with <pre><code> tags
if doc_label == DocItemLabel.CODE:
text_content = f"<pre><code>{text_content}</code></pre>"
doc.add_text(
label=DocItemLabel.LIST_ITEM,
text=content,
parent=current_group,
prov=prov_item_inst if prov_item_inst else None,
label=doc_label,
text=text_content,
prov=(
ProvenanceItem(bbox=bbox, charspan=(0, 0), page_no=page_no)
if bbox
else None
),
)
elif line.startswith("<caption>"):
prov_item_inst = None
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "magenta"))
prov_item_inst = ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=content,
parent=current_group,
prov=prov_item_inst if prov_item_inst else None,
)
elif line.startswith("<checkbox_unselected>"):
prov_item_inst = None
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "gray"))
prov_item_inst = ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
doc.add_text(
label=DocItemLabel.CHECKBOX_UNSELECTED,
text=content,
parent=current_group,
prov=prov_item_inst if prov_item_inst else None,
)
elif line.startswith("<checkbox_selected>"):
prov_item_inst = None
prov_item = extract_bounding_box(line)
if self.force_backend_text:
content = extract_text_from_backend(page, prov_item)
else:
content = extract_text(line)
if prov_item:
bounding_boxes.append((prov_item, "black"))
prov_item_inst = ProvenanceItem(
bbox=prov_item, charspan=(0, 0), page_no=page_no
)
doc.add_text(
label=DocItemLabel.CHECKBOX_SELECTED,
text=content,
parent=current_group,
prov=prov_item_inst if prov_item_inst else None,
)
# return doc, bounding_boxes
return doc
@classmethod