diff --git a/docling/models/smol_docling_model.py b/docling/models/smol_docling_model.py index 80c54f43..bcc7eadf 100644 --- a/docling/models/smol_docling_model.py +++ b/docling/models/smol_docling_model.py @@ -1,4 +1,5 @@ import logging +import time from pathlib import Path from typing import Iterable, List, Optional @@ -10,14 +11,7 @@ from transformers import ( # type: ignore Idefics3ForConditionalGeneration, ) -from docling.datamodel.base_models import ( - BoundingBox, - Cell, - Cluster, - DocTagsPrediction, - LayoutPrediction, - Page, -) +from docling.datamodel.base_models import DocTagsPrediction, Page from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.settings import settings @@ -31,7 +25,6 @@ _log = logging.getLogger(__name__) class SmolDoclingModel(BasePageModel): def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions): - print("SmolDocling, init...") device = decide_device(accelerator_options.device) self.device = device _log.info("Available device for SmolDocling: {}".format(device)) @@ -59,12 +52,10 @@ class SmolDoclingModel(BasePageModel): torch_dtype="auto", quantization_config=self.param_quantization_config, ) - print("SmolDocling, init... done!") def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: - print("SmolDocling, processing...") for page in page_batch: assert page._backend is not None if not page._backend.is_valid(): @@ -72,6 +63,7 @@ class SmolDoclingModel(BasePageModel): else: with TimeRecorder(conv_res, "smolvlm"): assert page.size is not None + start_time = time.time() hi_res_image = page.get_image(scale=2.0) # 144dpi # populate page_tags with predicted doc tags @@ -113,6 +105,9 @@ class SmolDoclingModel(BasePageModel): )[0] generated_texts = generated_texts.replace("Assistant: ", "") page_tags = generated_texts + + inference_time = time.time() - start_time + print(f"Page Inference Time: {inference_time:.2f} seconds") print("Page predictions:") print(page_tags) diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py index a48b7ab0..afff0cef 100644 --- a/docling/pipeline/vlm_pipeline.py +++ b/docling/pipeline/vlm_pipeline.py @@ -16,10 +16,12 @@ from docling_core.types.doc import ( ImageRefMode, PictureItem, ProvenanceItem, + Size, TableCell, TableData, TableItem, ) +from docling_core.types.doc.tokens import DocumentToken, TableToken from PIL.Image import Image from docling.backend.abstract_backend import AbstractDocumentBackend @@ -39,7 +41,6 @@ class VlmPipeline(PaginatedPipeline): def __init__(self, pipeline_options: PdfPipelineOptions): super().__init__(pipeline_options) - print("------> Init VLM Pipeline!") self.pipeline_options: PdfPipelineOptions if pipeline_options.artifacts_path is None: @@ -91,7 +92,6 @@ class VlmPipeline(PaginatedPipeline): return page def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult: - print("VLM, assembling document...") with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT): # Read and concatenate the page doctags: @@ -202,7 +202,6 @@ class VlmPipeline(PaginatedPipeline): if not x ] table_cells = [] - # print("\nText parts:") r_idx = 0 c_idx = 0 @@ -227,7 +226,6 @@ class VlmPipeline(PaginatedPipeline): return span for i, text in enumerate(texts): - # print(f" {text}") cell_text = "" if text in ["", "", "", "", ""]: row_span = 1 @@ -323,8 +321,13 @@ class VlmPipeline(PaginatedPipeline): for pg_idx, xml_content in enumerate(full_doc_xml_content): pil_image = pil_images[pg_idx] page_no = pg_idx + 1 + + if pil_image: + pg_width, pg_height = pil_image.size + size = Size(width=pg_width, height=pg_height) + parent_page = doc.add_page(page_no=page_no, size=size) + lines = xml_content.split("\n") - # pil_image = input_image #Image.open(BytesIO(image_bytes)) bounding_boxes = [] for line in lines: diff --git a/docs/examples/minimal_smol_docling.py b/docs/examples/minimal_smol_docling.py index c9d9f9d0..50dbd0dc 100644 --- a/docs/examples/minimal_smol_docling.py +++ b/docs/examples/minimal_smol_docling.py @@ -1,4 +1,7 @@ +import os +import time from pathlib import Path +from urllib.parse import urlparse from docling.backend.docling_parse_backend import DoclingParseDocumentBackend from docling.datamodel.base_models import InputFormat @@ -6,12 +9,18 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.document_converter import DocumentConverter, PdfFormatOption from docling.pipeline.vlm_pipeline import VlmPipeline -source = "https://arxiv.org/pdf/2408.09869" # document per local path or URL +# source = "https://arxiv.org/pdf/2408.09869" # document per local path or URL # source = "tests/data/2305.03393v1-pg9-img.png" -# source = "tests/data/2305.03393v1-pg9.pdf" +source = "tests/data/2305.03393v1-pg9.pdf" # source = "demo_data/page.png" # source = "demo_data/original_tables.pdf" +parsed = urlparse(source) +if parsed.scheme in ("http", "https"): + out_name = os.path.basename(parsed.path) +else: + out_name = os.path.basename(source) + pipeline_options = PdfPipelineOptions() pipeline_options.generate_page_images = True pipeline_options.artifacts_path = "model_artifacts" @@ -32,6 +41,7 @@ converter = DocumentConverter( } ) +start_time = time.time() print("============") print("starting...") print("============") @@ -39,12 +49,6 @@ print("") result = converter.convert(source) -# print("------------") -# print("result:") -# print("------------") -# print("") -# print(result) - print("------------") print("MD:") print("------------") @@ -53,12 +57,16 @@ print(result.document.export_to_markdown()) Path("scratch").mkdir(parents=True, exist_ok=True) result.document.save_as_html( - filename=Path("scratch/smol_export.html"), + filename=Path("scratch/{}.html".format(out_name)), image_mode=ImageRefMode.REFERENCED, labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE], ) +pg_num = result.document.num_pages() + print("") +inference_time = time.time() - start_time +print(f"Total document prediction time: {inference_time:.2f} seconds, pages: {pg_num}") print("============") print("done!") print("============")