Updated vlm pipeline assembly and smol docling model code to support updated doctags

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
Maksym Lysak 2025-01-17 17:54:55 +01:00
parent f6d123a01c
commit 0fe12d819a
3 changed files with 29 additions and 14 deletions

View File

@ -97,16 +97,21 @@ class SmolDoclingModel(BasePageModel):
start_time = time.time() start_time = time.time()
# Call model to generate: # Call model to generate:
generated_ids = self.vlm_model.generate( generated_ids = self.vlm_model.generate(
**inputs, max_new_tokens=4096 **inputs, max_new_tokens=4096, use_cache=True
) )
generation_time = time.time() - start_time generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode( generated_texts = self.processor.batch_decode(
generated_ids, skip_special_tokens=True generated_ids, skip_special_tokens=False
)[0] )[0]
num_tokens = len(generated_ids[0]) num_tokens = len(generated_ids[0])
generated_texts = generated_texts.replace("Assistant: ", "") # DELETE NOISE BEFORE "Assistant: "
starting_point = "Assistant: "
generated_texts = generated_texts[
generated_texts.index(starting_point) + len(starting_point) :
]
# generated_texts = generated_texts.replace("Assistant: ", "")
page_tags = generated_texts page_tags = generated_texts
inference_time = time.time() - start_time inference_time = time.time() - start_time

View File

@ -36,7 +36,8 @@ _log = logging.getLogger(__name__)
class VlmPipeline(PaginatedPipeline): class VlmPipeline(PaginatedPipeline):
_smol_vlm_path = "SmolDocling-0.0.2" # _smol_vlm_path = "SmolDocling-0.0.2"
_smol_vlm_path = "SmolDocling_2.7_DT_0.7"
def __init__(self, pipeline_options: PdfPipelineOptions): def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options) super().__init__(pipeline_options)
@ -207,7 +208,9 @@ class VlmPipeline(PaginatedPipeline):
right_offset = 2 right_offset = 2
# Check next element(s) for lcel / ucel / xcel, set properly row_span, col_span # Check next element(s) for lcel / ucel / xcel, set properly row_span, col_span
next_right_cell = texts[i + right_offset] next_right_cell = ""
if i + right_offset < len(texts):
next_right_cell = texts[i + right_offset]
next_bottom_cell = "" next_bottom_cell = ""
if r_idx + 1 < len(split_row_tokens): if r_idx + 1 < len(split_row_tokens):
@ -367,7 +370,7 @@ class VlmPipeline(PaginatedPipeline):
), ),
) )
elif line.startswith("<section-header>"): elif line.startswith("<section_header_level_1>"):
prov_item = extract_bounding_box(line) prov_item = extract_bounding_box(line)
if self.force_backend_text: if self.force_backend_text:
content = extract_text_from_backend(page, prov_item) content = extract_text_from_backend(page, prov_item)
@ -421,7 +424,7 @@ class VlmPipeline(PaginatedPipeline):
), ),
) )
elif line.startswith("<page-header>"): elif line.startswith("<page_header>"):
prov_item = extract_bounding_box(line) prov_item = extract_bounding_box(line)
if self.force_backend_text: if self.force_backend_text:
content = extract_text_from_backend(page, prov_item) content = extract_text_from_backend(page, prov_item)
@ -442,7 +445,7 @@ class VlmPipeline(PaginatedPipeline):
), ),
) )
elif line.startswith("<page-footer>"): elif line.startswith("<page_footer>"):
prov_item = extract_bounding_box(line) prov_item = extract_bounding_box(line)
if self.force_backend_text: if self.force_backend_text:
content = extract_text_from_backend(page, prov_item) content = extract_text_from_backend(page, prov_item)
@ -463,7 +466,7 @@ class VlmPipeline(PaginatedPipeline):
), ),
) )
elif line.startswith("<figure>"): elif line.startswith("<picture>"):
bbox = extract_bounding_box(line) bbox = extract_bounding_box(line)
if bbox: if bbox:
bounding_boxes.append((bbox, "yellow")) bounding_boxes.append((bbox, "yellow"))
@ -492,7 +495,7 @@ class VlmPipeline(PaginatedPipeline):
bbox=bbox, charspan=(0, 0), page_no=page_no bbox=bbox, charspan=(0, 0), page_no=page_no
), ),
) )
elif line.startswith("<list>"): elif line.startswith("<list_item>"):
prov_item_inst = None prov_item_inst = None
prov_item = extract_bounding_box(line) prov_item = extract_bounding_box(line)
if self.force_backend_text: if self.force_backend_text:
@ -529,7 +532,7 @@ class VlmPipeline(PaginatedPipeline):
parent=current_group, parent=current_group,
prov=prov_item_inst if prov_item_inst else None, prov=prov_item_inst if prov_item_inst else None,
) )
elif line.startswith("<checkbox-unselected>"): elif line.startswith("<checkbox_unselected>"):
prov_item_inst = None prov_item_inst = None
prov_item = extract_bounding_box(line) prov_item = extract_bounding_box(line)
if self.force_backend_text: if self.force_backend_text:
@ -548,7 +551,7 @@ class VlmPipeline(PaginatedPipeline):
prov=prov_item_inst if prov_item_inst else None, prov=prov_item_inst if prov_item_inst else None,
) )
elif line.startswith("<checkbox-selected>"): elif line.startswith("<checkbox_selected>"):
prov_item_inst = None prov_item_inst = None
prov_item = extract_bounding_box(line) prov_item = extract_bounding_box(line)
if self.force_backend_text: if self.force_backend_text:

View File

@ -13,6 +13,7 @@ from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline from docling.pipeline.vlm_pipeline import VlmPipeline
sources = [ sources = [
# "https://arxiv.org/pdf/2408.09869"
# "tests/data/2305.03393v1-pg9-img.png", # "tests/data/2305.03393v1-pg9-img.png",
"tests/data/2305.03393v1-pg9.pdf", "tests/data/2305.03393v1-pg9.pdf",
# "demo_data/page.png", # "demo_data/page.png",
@ -60,8 +61,14 @@ for source in sources:
print("") print("")
print(res.document.export_to_markdown()) print(res.document.export_to_markdown())
with (out_path / f"{res.input.file.stem}.html").open("w") as fp: # with (out_path / f"{res.input.file.stem}.html").open("w") as fp:
fp.write(res.document.export_to_html()) # fp.write(res.document.export_to_html())
res.document.save_as_html(
filename=Path("{}/{}.html".format(out_path, res.input.file.stem)),
image_mode=ImageRefMode.REFERENCED,
labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE],
)
with (out_path / f"{res.input.file.stem}.json").open("w") as fp: with (out_path / f"{res.input.file.stem}.json").open("w") as fp:
fp.write(json.dumps(res.document.export_to_dict())) fp.write(json.dumps(res.document.export_to_dict()))