From 479ee239aa7d20e0e0b997e580dfb0c5f01297df Mon Sep 17 00:00:00 2001 From: Maksym Lysak Date: Tue, 11 Feb 2025 13:34:14 +0100 Subject: [PATCH] New assembly code for latest model revision, updated prompt and parsing of doctags, updated logging Signed-off-by: Maksym Lysak --- docling/models/smol_docling_model.py | 17 +- docling/pipeline/vlm_pipeline.py | 378 +++++++++------------------ 2 files changed, 138 insertions(+), 257 deletions(-) diff --git a/docling/models/smol_docling_model.py b/docling/models/smol_docling_model.py index 79d2affd..7bbf228a 100644 --- a/docling/models/smol_docling_model.py +++ b/docling/models/smol_docling_model.py @@ -30,7 +30,6 @@ class SmolDoclingModel(BasePageModel): def __init__( self, - # artifacts_path: Path, accelerator_options: AcceleratorOptions, vlm_options: SmolDoclingOptions, ): @@ -76,6 +75,16 @@ class SmolDoclingModel(BasePageModel): assert page.size is not None hi_res_image = page.get_image(scale=2.0) # 144dpi + # hi_res_image = page.get_image(scale=1.0) # 72dpi + + if hi_res_image is not None: + im_width, im_height = hi_res_image.size + print( + "Processed image resolution: {} x {}".format( + im_width, im_height + ) + ) + # populate page_tags with predicted doc tags page_tags = "" @@ -103,8 +112,8 @@ class SmolDoclingModel(BasePageModel): text=prompt, images=[hi_res_image], return_tensors="pt" ) inputs = {k: v.to(self.device) for k, v in inputs.items()} - prompt = prompt.replace("", "") + print("In the model, starting to generate...") start_time = time.time() # Call model to generate: generated_ids = self.vlm_model.generate( @@ -129,11 +138,9 @@ class SmolDoclingModel(BasePageModel): tokens_per_second = num_tokens / generation_time print("") print(f"Page Inference Time: {inference_time:.2f} seconds") + print(f"Total tokens on page: {num_tokens:.2f}") print(f"Tokens/sec: {tokens_per_second:.2f}") print("") - print("Page predictions:") - print(page_tags) - page.predictions.doctags = DocTagsPrediction(tag_string=page_tags) yield page diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py index 357499f6..26b8d2d8 100644 --- a/docling/pipeline/vlm_pipeline.py +++ b/docling/pipeline/vlm_pipeline.py @@ -1,7 +1,8 @@ import itertools import logging import re -from io import BytesIO + +# from io import BytesIO from pathlib import Path from typing import Optional @@ -45,6 +46,49 @@ class VlmPipeline(PaginatedPipeline): # force_backend_text = True - get text from backend using bounding boxes predicted by SmolDoclingss self.force_backend_text = pipeline_options.force_backend_text + ############################################### + # Tag definitions and color mappings + ############################################### + + # Maps the recognized tag to a Docling label. + # Code items will be given DocItemLabel.CODE + self.tag_to_doclabel = { + "title": DocItemLabel.TITLE, + "document_index": DocItemLabel.DOCUMENT_INDEX, + "otsl": DocItemLabel.TABLE, + "section_header_level_1": DocItemLabel.SECTION_HEADER, + "checkbox_selected": DocItemLabel.CHECKBOX_SELECTED, + "checkbox_unselected": DocItemLabel.CHECKBOX_UNSELECTED, + "text": DocItemLabel.TEXT, + "page_header": DocItemLabel.PAGE_HEADER, + "page_footer": DocItemLabel.PAGE_FOOTER, + "formula": DocItemLabel.FORMULA, + "caption": DocItemLabel.CAPTION, + "picture": DocItemLabel.PICTURE, + "list_item": DocItemLabel.LIST_ITEM, + "footnote": DocItemLabel.FOOTNOTE, + "code": DocItemLabel.CODE, + } + + # Maps each tag to an associated bounding box color. + self.tag_to_color = { + "title": "blue", + "document_index": "darkblue", + "otsl": "green", + "section_header_level_1": "purple", + "checkbox_selected": "black", + "checkbox_unselected": "gray", + "text": "red", + "page_header": "orange", + "page_footer": "cyan", + "formula": "pink", + "caption": "magenta", + "picture": "yellow", + "list_item": "brown", + "footnote": "darkred", + "code": "lightblue", + } + """ if pipeline_options.artifacts_path is None: self.artifacts_path = self.download_models_hf() @@ -136,6 +180,18 @@ class VlmPipeline(PaginatedPipeline): def _turn_tags_into_doc(self, pages: list[Page]) -> DoclingDocument: + def extract_bounding_box(text_chunk: str) -> Optional[BoundingBox]: + """Extracts bounding box coords from the chunk, normalized by / 500.""" + coords = re.findall(r"", text_chunk) + if len(coords) == 4: + l, t, r, b = map(float, coords) + return BoundingBox(l=l / 500, t=t / 500, r=r / 500, b=b / 500) + return None + + def extract_inner_text(text_chunk: str) -> str: + """Strips all <...> tags inside the chunk to get the raw text content.""" + return re.sub(r"<.*?>", "", text_chunk, flags=re.DOTALL).strip() + def extract_text_from_backend(page: Page, bbox: BoundingBox | None) -> str: # Convert bounding box normalized to 0-100 into page coordinates for cropping text = "" @@ -149,18 +205,7 @@ class VlmPipeline(PaginatedPipeline): text = page._backend.get_text_in_rect(bbox) return text - def extract_text(tag_content: str) -> str: - return re.sub(r"<.*?>", "", tag_content).strip() - - def extract_bounding_box(tag_content: str) -> Optional[BoundingBox]: - locs = re.findall(r"", tag_content) - if len(locs) == 4: - l, t, r, b = map(float, locs) - l, t, r, b = [coord / 500.0 for coord in (l, t, r, b)] - return BoundingBox(l=l, t=t, r=r, b=b) - return None - - def parse_texts(texts, tokens): + def otsl_parse_texts(texts, tokens): split_word = TableToken.OTSL_NL.value split_row_tokens = [ list(y) @@ -267,7 +312,7 @@ class VlmPipeline(PaginatedPipeline): c_idx = 0 return table_cells, split_row_tokens - def extract_tokens_and_text(s: str): + def otsl_extract_tokens_and_text(s: str): # Pattern to match anything enclosed by < > (including the angle brackets themselves) pattern = r"(<[^>]+>)" # Find all tokens (e.g. "", "", etc.) @@ -291,8 +336,8 @@ class VlmPipeline(PaginatedPipeline): return tokens, text_parts def parse_table_content(otsl_content: str) -> TableData: - tokens, mixed_texts = extract_tokens_and_text(otsl_content) - table_cells, split_row_tokens = parse_texts(mixed_texts, tokens) + tokens, mixed_texts = otsl_extract_tokens_and_text(otsl_content) + table_cells, split_row_tokens = otsl_parse_texts(mixed_texts, tokens) return TableData( num_rows=len(split_row_tokens), @@ -302,15 +347,18 @@ class VlmPipeline(PaginatedPipeline): table_cells=table_cells, ) - doc = DoclingDocument(name="Example Document") - current_group = None - + doc = DoclingDocument(name="Document") for pg_idx, page in enumerate(pages): xml_content = "" + predicted_text = "" if page.predictions.doctags: - xml_content = page.predictions.doctags.tag_string - pil_image = page.image + print("Doctags predicted for a page {}:".format(pg_idx)) + print(page.predictions.doctags) + print("") + predicted_text = page.predictions.doctags.tag_string + image = page.image page_no = pg_idx + 1 + bounding_boxes = [] if page.size: pg_width = page.size.width @@ -318,258 +366,84 @@ class VlmPipeline(PaginatedPipeline): size = Size(width=pg_width, height=pg_height) parent_page = doc.add_page(page_no=page_no, size=size) - lines = xml_content.split("\n") - bounding_boxes = [] + """ + 1. Finds all ... blocks in the entire string (multi-line friendly) in the order they appear. + 2. For each chunk, extracts bounding box (if any) and inner text. + 3. Adds the item to a DoclingDocument structure with the right label. + 4. Tracks bounding boxes + color in a separate list for later visualization. + """ - for line in lines: - line = line.strip() - line = line.replace("", "") - if line.startswith(""): - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) + # Regex for all recognized tags + tag_pattern = ( + r"<(?Ptitle|document_index|otsl|section_header_level_1|checkbox_selected|" + r"checkbox_unselected|text|page_header|page_footer|formula|caption|picture|" + r"list_item|footnote|code)>.*?" + ) + pattern = re.compile(tag_pattern, re.DOTALL) - if prov_item: - bounding_boxes.append((prov_item, "red")) - doc.add_text( - label=DocItemLabel.PARAGRAPH, - text=content, - parent=current_group, - prov=( - ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - if prov_item - else None - ), - ) - elif line.startswith(""): - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) + # Go through each match in order + for match in pattern.finditer(predicted_text): + full_chunk = match.group(0) + tag_name = match.group("tag") - if prov_item: - bounding_boxes.append((prov_item, "blue")) - current_group = doc.add_group( - label=GroupLabel.SECTION, name=content - ) - doc.add_text( - label=DocItemLabel.TITLE, - text=content, - parent=current_group, - prov=( - ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - if prov_item - else None - ), - ) + bbox = extract_bounding_box(full_chunk) + doc_label = self.tag_to_doclabel.get(tag_name, DocItemLabel.PARAGRAPH) + color = self.tag_to_color.get(tag_name, "white") - elif line.startswith("<section_header_level_1>"): - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) + # Store bounding box + color + if bbox: + bounding_boxes.append((bbox, color)) - if prov_item: - bounding_boxes.append((prov_item, "green")) - current_group = doc.add_group( - label=GroupLabel.SECTION, name=content - ) - doc.add_text( - label=DocItemLabel.SECTION_HEADER, - text=content, - parent=current_group, - prov=( - ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - if prov_item - else None - ), - ) + if tag_name == "otsl": + table_data = parse_table_content(full_chunk) + doc.add_table(data=table_data) - elif line.startswith("<otsl>"): - prov_item = extract_bounding_box(line) - if prov_item: - bounding_boxes.append((prov_item, "aquamarine")) - - table_data = parse_table_content(line) - doc.add_table(data=table_data, parent=current_group) - - elif line.startswith("<footnote>"): - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "orange")) - doc.add_text( - label=DocItemLabel.FOOTNOTE, - text=content, - parent=current_group, - prov=( - ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - if prov_item - else None - ), - ) - - elif line.startswith("<page_header>"): - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "purple")) - doc.add_text( - label=DocItemLabel.PAGE_HEADER, - text=content, - parent=current_group, - prov=( - ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - if prov_item - else None - ), - ) - - elif line.startswith("<page_footer>"): - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "cyan")) - doc.add_text( - label=DocItemLabel.PAGE_FOOTER, - text=content, - parent=current_group, - prov=( - ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - if prov_item - else None - ), - ) - - elif line.startswith("<picture>"): - bbox = extract_bounding_box(line) - if bbox: - bounding_boxes.append((bbox, "yellow")) - if pil_image: - # Convert bounding box normalized to 0-100 into pixel coordinates for cropping - width, height = pil_image.size + elif tag_name == "picture": + if image: + if bbox: + width, height = image.size crop_box = ( int(bbox.l * width), int(bbox.t * height), int(bbox.r * width), int(bbox.b * height), ) - - cropped_image = pil_image.crop(crop_box) + cropped_image = image.crop(crop_box) doc.add_picture( - parent=current_group, - image=ImageRef.from_pil(image=cropped_image, dpi=300), + parent=None, + image=ImageRef.from_pil(image=cropped_image, dpi=72), + prov=( + ProvenanceItem( + bbox=bbox, charspan=(0, 0), page_no=page_no + ) + ), + ) + else: + if bbox: + doc.add_picture( + parent=None, prov=ProvenanceItem( bbox=bbox, charspan=(0, 0), page_no=page_no ), ) - else: - doc.add_picture( - parent=current_group, - prov=ProvenanceItem( - bbox=bbox, charspan=(0, 0), page_no=page_no - ), - ) - elif line.startswith("<list_item>"): - prov_item_inst = None - prov_item = extract_bounding_box(line) + else: + # For everything else, treat as text if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) + content = extract_text_from_backend(page, bbox) else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "brown")) - prov_item_inst = ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) + text_content = extract_inner_text(full_chunk) + # If it's code, wrap it with <pre><code> tags + if doc_label == DocItemLabel.CODE: + text_content = f"<pre><code>{text_content}</code></pre>" doc.add_text( - label=DocItemLabel.LIST_ITEM, - text=content, - parent=current_group, - prov=prov_item_inst if prov_item_inst else None, + label=doc_label, + text=text_content, + prov=( + ProvenanceItem(bbox=bbox, charspan=(0, 0), page_no=page_no) + if bbox + else None + ), ) - - elif line.startswith("<caption>"): - prov_item_inst = None - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "magenta")) - prov_item_inst = ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - doc.add_text( - label=DocItemLabel.PARAGRAPH, - text=content, - parent=current_group, - prov=prov_item_inst if prov_item_inst else None, - ) - elif line.startswith("<checkbox_unselected>"): - prov_item_inst = None - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "gray")) - prov_item_inst = ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - doc.add_text( - label=DocItemLabel.CHECKBOX_UNSELECTED, - text=content, - parent=current_group, - prov=prov_item_inst if prov_item_inst else None, - ) - - elif line.startswith("<checkbox_selected>"): - prov_item_inst = None - prov_item = extract_bounding_box(line) - if self.force_backend_text: - content = extract_text_from_backend(page, prov_item) - else: - content = extract_text(line) - if prov_item: - bounding_boxes.append((prov_item, "black")) - prov_item_inst = ProvenanceItem( - bbox=prov_item, charspan=(0, 0), page_no=page_no - ) - doc.add_text( - label=DocItemLabel.CHECKBOX_SELECTED, - text=content, - parent=current_group, - prov=prov_item_inst if prov_item_inst else None, - ) - # return doc, bounding_boxes return doc @classmethod