mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
New assembly code for latest model revision, updated prompt and parsing of doctags, updated logging
Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
parent
7c4ab5c716
commit
479ee239aa
@ -30,7 +30,6 @@ class SmolDoclingModel(BasePageModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# artifacts_path: Path,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
vlm_options: SmolDoclingOptions,
|
||||
):
|
||||
@ -76,6 +75,16 @@ class SmolDoclingModel(BasePageModel):
|
||||
assert page.size is not None
|
||||
|
||||
hi_res_image = page.get_image(scale=2.0) # 144dpi
|
||||
# hi_res_image = page.get_image(scale=1.0) # 72dpi
|
||||
|
||||
if hi_res_image is not None:
|
||||
im_width, im_height = hi_res_image.size
|
||||
print(
|
||||
"Processed image resolution: {} x {}".format(
|
||||
im_width, im_height
|
||||
)
|
||||
)
|
||||
|
||||
# populate page_tags with predicted doc tags
|
||||
page_tags = ""
|
||||
|
||||
@ -103,8 +112,8 @@ class SmolDoclingModel(BasePageModel):
|
||||
text=prompt, images=[hi_res_image], return_tensors="pt"
|
||||
)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
prompt = prompt.replace("<end_of_utterance>", "")
|
||||
|
||||
print("In the model, starting to generate...")
|
||||
start_time = time.time()
|
||||
# Call model to generate:
|
||||
generated_ids = self.vlm_model.generate(
|
||||
@ -129,11 +138,9 @@ class SmolDoclingModel(BasePageModel):
|
||||
tokens_per_second = num_tokens / generation_time
|
||||
print("")
|
||||
print(f"Page Inference Time: {inference_time:.2f} seconds")
|
||||
print(f"Total tokens on page: {num_tokens:.2f}")
|
||||
print(f"Tokens/sec: {tokens_per_second:.2f}")
|
||||
print("")
|
||||
print("Page predictions:")
|
||||
print(page_tags)
|
||||
|
||||
page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)
|
||||
|
||||
yield page
|
||||
|
@ -1,7 +1,8 @@
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
# from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -45,6 +46,49 @@ class VlmPipeline(PaginatedPipeline):
|
||||
# force_backend_text = True - get text from backend using bounding boxes predicted by SmolDocling
|
||||
self.force_backend_text = pipeline_options.force_backend_text
|
||||
|
||||
###############################################
|
||||
# Tag definitions and color mappings
|
||||
###############################################
|
||||
|
||||
# Maps the recognized tag to a Docling label.
|
||||
# Code items will be given DocItemLabel.CODE
|
||||
self.tag_to_doclabel = {
|
||||
"title": DocItemLabel.TITLE,
|
||||
"document_index": DocItemLabel.DOCUMENT_INDEX,
|
||||
"otsl": DocItemLabel.TABLE,
|
||||
"section_header_level_1": DocItemLabel.SECTION_HEADER,
|
||||
"checkbox_selected": DocItemLabel.CHECKBOX_SELECTED,
|
||||
"checkbox_unselected": DocItemLabel.CHECKBOX_UNSELECTED,
|
||||
"text": DocItemLabel.TEXT,
|
||||
"page_header": DocItemLabel.PAGE_HEADER,
|
||||
"page_footer": DocItemLabel.PAGE_FOOTER,
|
||||
"formula": DocItemLabel.FORMULA,
|
||||
"caption": DocItemLabel.CAPTION,
|
||||
"picture": DocItemLabel.PICTURE,
|
||||
"list_item": DocItemLabel.LIST_ITEM,
|
||||
"footnote": DocItemLabel.FOOTNOTE,
|
||||
"code": DocItemLabel.CODE,
|
||||
}
|
||||
|
||||
# Maps each tag to an associated bounding box color.
|
||||
self.tag_to_color = {
|
||||
"title": "blue",
|
||||
"document_index": "darkblue",
|
||||
"otsl": "green",
|
||||
"section_header_level_1": "purple",
|
||||
"checkbox_selected": "black",
|
||||
"checkbox_unselected": "gray",
|
||||
"text": "red",
|
||||
"page_header": "orange",
|
||||
"page_footer": "cyan",
|
||||
"formula": "pink",
|
||||
"caption": "magenta",
|
||||
"picture": "yellow",
|
||||
"list_item": "brown",
|
||||
"footnote": "darkred",
|
||||
"code": "lightblue",
|
||||
}
|
||||
|
||||
"""
|
||||
if pipeline_options.artifacts_path is None:
|
||||
self.artifacts_path = self.download_models_hf()
|
||||
@ -136,6 +180,18 @@ class VlmPipeline(PaginatedPipeline):
|
||||
|
||||
def _turn_tags_into_doc(self, pages: list[Page]) -> DoclingDocument:
|
||||
|
||||
def extract_bounding_box(text_chunk: str) -> Optional[BoundingBox]:
    """Extract the <loc_...> bounding box of a chunk, normalized by / 500.

    The first four ``<loc_N>`` tokens found in ``text_chunk`` are read, in
    order, as left, top, right and bottom, and each value is divided by 500.

    Args:
        text_chunk: Raw tagged text of one element, including its tags.

    Returns:
        A ``BoundingBox`` built from the first four location tokens, or
        ``None`` when the chunk holds fewer than four of them.
    """
    coords = re.findall(r"<loc_(\d+)>", text_chunk)
    if len(coords) < 4:
        return None
    # A chunk can carry more than four location tokens (presumably when it
    # embeds another located element, e.g. a caption inside a picture).
    # Requiring exactly four would drop the outer element's box, so only the
    # leading four tokens are used.
    l, t, r, b = (float(value) / 500 for value in coords[:4])
    return BoundingBox(l=l, t=t, r=r, b=b)
|
||||
|
||||
def extract_inner_text(text_chunk: str) -> str:
    """Return the chunk's text content with every <...> tag removed."""
    # DOTALL lets a tag that spans a line break still be matched and dropped.
    tag_free = re.compile(r"<.*?>", re.DOTALL).sub("", text_chunk)
    return tag_free.strip()
|
||||
|
||||
def extract_text_from_backend(page: Page, bbox: BoundingBox | None) -> str:
|
||||
# Convert the normalized bounding box (loc values / 500) into page coordinates for text lookup
|
||||
text = ""
|
||||
@ -149,18 +205,7 @@ class VlmPipeline(PaginatedPipeline):
|
||||
text = page._backend.get_text_in_rect(bbox)
|
||||
return text
|
||||
|
||||
def extract_text(tag_content: str) -> str:
|
||||
return re.sub(r"<.*?>", "", tag_content).strip()
|
||||
|
||||
def extract_bounding_box(tag_content: str) -> Optional[BoundingBox]:
|
||||
locs = re.findall(r"<loc_(\d+)>", tag_content)
|
||||
if len(locs) == 4:
|
||||
l, t, r, b = map(float, locs)
|
||||
l, t, r, b = [coord / 500.0 for coord in (l, t, r, b)]
|
||||
return BoundingBox(l=l, t=t, r=r, b=b)
|
||||
return None
|
||||
|
||||
def parse_texts(texts, tokens):
|
||||
def otsl_parse_texts(texts, tokens):
|
||||
split_word = TableToken.OTSL_NL.value
|
||||
split_row_tokens = [
|
||||
list(y)
|
||||
@ -267,7 +312,7 @@ class VlmPipeline(PaginatedPipeline):
|
||||
c_idx = 0
|
||||
return table_cells, split_row_tokens
|
||||
|
||||
def extract_tokens_and_text(s: str):
|
||||
def otsl_extract_tokens_and_text(s: str):
|
||||
# Pattern to match anything enclosed by < > (including the angle brackets themselves)
|
||||
pattern = r"(<[^>]+>)"
|
||||
# Find all tokens (e.g. "<otsl>", "<loc_140>", etc.)
|
||||
@ -291,8 +336,8 @@ class VlmPipeline(PaginatedPipeline):
|
||||
return tokens, text_parts
|
||||
|
||||
def parse_table_content(otsl_content: str) -> TableData:
|
||||
tokens, mixed_texts = extract_tokens_and_text(otsl_content)
|
||||
table_cells, split_row_tokens = parse_texts(mixed_texts, tokens)
|
||||
tokens, mixed_texts = otsl_extract_tokens_and_text(otsl_content)
|
||||
table_cells, split_row_tokens = otsl_parse_texts(mixed_texts, tokens)
|
||||
|
||||
return TableData(
|
||||
num_rows=len(split_row_tokens),
|
||||
@ -302,15 +347,18 @@ class VlmPipeline(PaginatedPipeline):
|
||||
table_cells=table_cells,
|
||||
)
|
||||
|
||||
doc = DoclingDocument(name="Example Document")
|
||||
current_group = None
|
||||
|
||||
doc = DoclingDocument(name="Document")
|
||||
for pg_idx, page in enumerate(pages):
|
||||
xml_content = ""
|
||||
predicted_text = ""
|
||||
if page.predictions.doctags:
|
||||
xml_content = page.predictions.doctags.tag_string
|
||||
pil_image = page.image
|
||||
print("Doctags predicted for a page {}:".format(pg_idx))
|
||||
print(page.predictions.doctags)
|
||||
print("")
|
||||
predicted_text = page.predictions.doctags.tag_string
|
||||
image = page.image
|
||||
page_no = pg_idx + 1
|
||||
bounding_boxes = []
|
||||
|
||||
if page.size:
|
||||
pg_width = page.size.width
|
||||
@ -318,258 +366,84 @@ class VlmPipeline(PaginatedPipeline):
|
||||
size = Size(width=pg_width, height=pg_height)
|
||||
parent_page = doc.add_page(page_no=page_no, size=size)
|
||||
|
||||
lines = xml_content.split("\n")
|
||||
bounding_boxes = []
|
||||
"""
|
||||
1. Finds all <tag>...</tag> blocks in the entire string (multi-line friendly) in the order they appear.
|
||||
2. For each chunk, extracts bounding box (if any) and inner text.
|
||||
3. Adds the item to a DoclingDocument structure with the right label.
|
||||
4. Tracks bounding boxes + color in a separate list for later visualization.
|
||||
"""
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
line = line.replace("<doctag>", "")
|
||||
if line.startswith("<paragraph>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
# Regex for all recognized tags
|
||||
tag_pattern = (
|
||||
r"<(?P<tag>title|document_index|otsl|section_header_level_1|checkbox_selected|"
|
||||
r"checkbox_unselected|text|page_header|page_footer|formula|caption|picture|"
|
||||
r"list_item|footnote|code)>.*?</(?P=tag)>"
|
||||
)
|
||||
pattern = re.compile(tag_pattern, re.DOTALL)
|
||||
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "red"))
|
||||
doc.add_text(
|
||||
label=DocItemLabel.PARAGRAPH,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
if prov_item
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif line.startswith("<title>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
# Go through each match in order
|
||||
for match in pattern.finditer(predicted_text):
|
||||
full_chunk = match.group(0)
|
||||
tag_name = match.group("tag")
|
||||
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "blue"))
|
||||
current_group = doc.add_group(
|
||||
label=GroupLabel.SECTION, name=content
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.TITLE,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
if prov_item
|
||||
else None
|
||||
),
|
||||
)
|
||||
bbox = extract_bounding_box(full_chunk)
|
||||
doc_label = self.tag_to_doclabel.get(tag_name, DocItemLabel.PARAGRAPH)
|
||||
color = self.tag_to_color.get(tag_name, "white")
|
||||
|
||||
elif line.startswith("<section_header_level_1>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "green"))
|
||||
current_group = doc.add_group(
|
||||
label=GroupLabel.SECTION, name=content
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.SECTION_HEADER,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
if prov_item
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
elif line.startswith("<otsl>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "aquamarine"))
|
||||
|
||||
table_data = parse_table_content(line)
|
||||
doc.add_table(data=table_data, parent=current_group)
|
||||
|
||||
elif line.startswith("<footnote>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "orange"))
|
||||
doc.add_text(
|
||||
label=DocItemLabel.FOOTNOTE,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
if prov_item
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
elif line.startswith("<page_header>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "purple"))
|
||||
doc.add_text(
|
||||
label=DocItemLabel.PAGE_HEADER,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
if prov_item
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
elif line.startswith("<page_footer>"):
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "cyan"))
|
||||
doc.add_text(
|
||||
label=DocItemLabel.PAGE_FOOTER,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
if prov_item
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
elif line.startswith("<picture>"):
|
||||
bbox = extract_bounding_box(line)
|
||||
# Store bounding box + color
|
||||
if bbox:
|
||||
bounding_boxes.append((bbox, "yellow"))
|
||||
if pil_image:
|
||||
# Convert bounding box normalized to 0-100 into pixel coordinates for cropping
|
||||
width, height = pil_image.size
|
||||
bounding_boxes.append((bbox, color))
|
||||
|
||||
if tag_name == "otsl":
|
||||
table_data = parse_table_content(full_chunk)
|
||||
doc.add_table(data=table_data)
|
||||
|
||||
elif tag_name == "picture":
|
||||
if image:
|
||||
if bbox:
|
||||
width, height = image.size
|
||||
crop_box = (
|
||||
int(bbox.l * width),
|
||||
int(bbox.t * height),
|
||||
int(bbox.r * width),
|
||||
int(bbox.b * height),
|
||||
)
|
||||
|
||||
cropped_image = pil_image.crop(crop_box)
|
||||
cropped_image = image.crop(crop_box)
|
||||
doc.add_picture(
|
||||
parent=current_group,
|
||||
image=ImageRef.from_pil(image=cropped_image, dpi=300),
|
||||
parent=None,
|
||||
image=ImageRef.from_pil(image=cropped_image, dpi=72),
|
||||
prov=(
|
||||
ProvenanceItem(
|
||||
bbox=bbox, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
if bbox:
|
||||
doc.add_picture(
|
||||
parent=None,
|
||||
prov=ProvenanceItem(
|
||||
bbox=bbox, charspan=(0, 0), page_no=page_no
|
||||
),
|
||||
)
|
||||
else:
|
||||
doc.add_picture(
|
||||
parent=current_group,
|
||||
prov=ProvenanceItem(
|
||||
bbox=bbox, charspan=(0, 0), page_no=page_no
|
||||
# For everything else, treat as text
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, bbox)
|
||||
else:
|
||||
text_content = extract_inner_text(full_chunk)
|
||||
# If it's code, wrap it with <pre><code> tags
|
||||
if doc_label == DocItemLabel.CODE:
|
||||
text_content = f"<pre><code>{text_content}</code></pre>"
|
||||
doc.add_text(
|
||||
label=doc_label,
|
||||
text=text_content,
|
||||
prov=(
|
||||
ProvenanceItem(bbox=bbox, charspan=(0, 0), page_no=page_no)
|
||||
if bbox
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif line.startswith("<list_item>"):
|
||||
prov_item_inst = None
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "brown"))
|
||||
prov_item_inst = ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.LIST_ITEM,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=prov_item_inst if prov_item_inst else None,
|
||||
)
|
||||
|
||||
elif line.startswith("<caption>"):
|
||||
prov_item_inst = None
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "magenta"))
|
||||
prov_item_inst = ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.PARAGRAPH,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=prov_item_inst if prov_item_inst else None,
|
||||
)
|
||||
elif line.startswith("<checkbox_unselected>"):
|
||||
prov_item_inst = None
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "gray"))
|
||||
prov_item_inst = ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.CHECKBOX_UNSELECTED,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=prov_item_inst if prov_item_inst else None,
|
||||
)
|
||||
|
||||
elif line.startswith("<checkbox_selected>"):
|
||||
prov_item_inst = None
|
||||
prov_item = extract_bounding_box(line)
|
||||
if self.force_backend_text:
|
||||
content = extract_text_from_backend(page, prov_item)
|
||||
else:
|
||||
content = extract_text(line)
|
||||
if prov_item:
|
||||
bounding_boxes.append((prov_item, "black"))
|
||||
prov_item_inst = ProvenanceItem(
|
||||
bbox=prov_item, charspan=(0, 0), page_no=page_no
|
||||
)
|
||||
doc.add_text(
|
||||
label=DocItemLabel.CHECKBOX_SELECTED,
|
||||
text=content,
|
||||
parent=current_group,
|
||||
prov=prov_item_inst if prov_item_inst else None,
|
||||
)
|
||||
# return doc, bounding_boxes
|
||||
return doc
|
||||
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user