mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
Skeleton for SmolDocling model and VLM Pipeline
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
parent
1d17e7397a
commit
dc3a388aa2
@ -154,6 +154,10 @@ class LayoutPrediction(BaseModel):
|
|||||||
clusters: List[Cluster] = []
|
clusters: List[Cluster] = []
|
||||||
|
|
||||||
|
|
||||||
|
# Prediction container for the DocTags markup produced for a page.
class DocTagsPrediction(BaseModel):
    # Raw DocTags string; stays empty until a model fills it in.
    tag_string: str = ""
|
||||||
|
|
||||||
|
|
||||||
class ContainerElement(
|
class ContainerElement(
|
||||||
BasePageElement
|
BasePageElement
|
||||||
): # Used for Form and Key-Value-Regions, only for typing.
|
): # Used for Form and Key-Value-Regions, only for typing.
|
||||||
@ -197,6 +201,7 @@ class PagePredictions(BaseModel):
|
|||||||
tablestructure: Optional[TableStructurePrediction] = None
|
tablestructure: Optional[TableStructurePrediction] = None
|
||||||
figures_classification: Optional[FigureClassificationPrediction] = None
|
figures_classification: Optional[FigureClassificationPrediction] = None
|
||||||
equations_prediction: Optional[EquationPrediction] = None
|
equations_prediction: Optional[EquationPrediction] = None
|
||||||
|
doctags: Optional[DocTagsPrediction] = None
|
||||||
|
|
||||||
|
|
||||||
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
||||||
|
58
docling/models/smol_docling_model.py
Normal file
58
docling/models/smol_docling_model.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||||
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
from docling.datamodel.base_models import (
|
||||||
|
BoundingBox,
|
||||||
|
Cell,
|
||||||
|
Cluster,
|
||||||
|
DocTagsPrediction,
|
||||||
|
LayoutPrediction,
|
||||||
|
Page,
|
||||||
|
)
|
||||||
|
from docling.datamodel.document import ConversionResult
|
||||||
|
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
|
||||||
|
from docling.datamodel.settings import settings
|
||||||
|
from docling.models.base_model import BasePageModel
|
||||||
|
from docling.utils.accelerator_utils import decide_device
|
||||||
|
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
||||||
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SmolDoclingModel(BasePageModel):
    """Skeleton page model that attaches a DocTags prediction to each valid page.

    The VLM predictor is not wired in yet, so every processed page currently
    receives an empty tag string.
    """

    def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
        """Resolve the accelerator device for the predictor that is still to come.

        Args:
            artifacts_path: Location of the model artifacts (not consumed yet).
            accelerator_options: Options whose ``device`` is passed to
                ``decide_device``.
        """
        device = decide_device(accelerator_options.device)

        # self.your_vlm_predictor(..., device) = None # TODO

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Yield every page of ``page_batch``, adding a DocTags prediction to valid ones.

        Args:
            conv_res: Conversion result handed to ``TimeRecorder`` for profiling.
            page_batch: Pages to process; each one must already have a backend.

        Yields:
            The same page objects, in input order.
        """
        for page in page_batch:
            backend = page._backend
            assert backend is not None

            # Pages with an invalid backend are passed through untouched.
            if not backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "smolvlm"):
                assert page.size is not None

                hi_res_image = page.get_image(scale=2.0)  # 144dpi

                # Call your self.your_vlm_predictor with the page image as input (hi_res_image)
                # populate page_tags
                page_tags = ""

                page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)

            # Hand the page downstream only after the timing scope has closed.
            yield page
|
139
docling/pipeline/vlm_pipeline.py
Normal file
139
docling/pipeline/vlm_pipeline.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from docling_core.types import DoclingDocument
|
||||||
|
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
|
||||||
|
|
||||||
|
from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||||
|
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||||
|
from docling.datamodel.base_models import Page
|
||||||
|
from docling.datamodel.document import ConversionResult
|
||||||
|
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||||
|
from docling.models.smol_docling_model import SmolDoclingModel
|
||||||
|
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||||
|
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VlmPipeline(PaginatedPipeline):
    """Paginated pipeline that assembles a document from per-page DocTags.

    Every page is run through ``SmolDoclingModel``; during assembly the page
    tag strings are concatenated and handed to ``_turn_tags_into_doc``.
    """

    _smol_vlm_path = "model_artifacts/smol_vlm"  # TODO or whatever is needed.

    def __init__(self, pipeline_options: PdfPipelineOptions):
        """Set up the artifacts location and the model pipes.

        Args:
            pipeline_options: Pipeline configuration. When its
                ``artifacts_path`` is ``None`` the artifacts are fetched via
                ``download_models_hf``.
        """
        super().__init__(pipeline_options)
        self.pipeline_options: PdfPipelineOptions

        if pipeline_options.artifacts_path is None:
            self.artifacts_path = self.download_models_hf()
        else:
            self.artifacts_path = Path(pipeline_options.artifacts_path)

        # NOTE(review): computed but not consumed anywhere in this skeleton yet;
        # presumably meant to control whether page images are retained - confirm.
        keep_images = (
            self.pipeline_options.generate_page_images
            or self.pipeline_options.generate_picture_images
            or self.pipeline_options.generate_table_images
        )

        self.build_pipe = [
            SmolDoclingModel(
                artifacts_path=self.artifacts_path / VlmPipeline._smol_vlm_path,
                accelerator_options=pipeline_options.accelerator_options,
            ),
        ]

        self.enrichment_pipe = [
            # Other models working on `NodeItem` elements in the DoclingDocument
        ]

    @staticmethod
    def download_models_hf(
        local_dir: Optional[Path] = None, force: bool = False
    ) -> Path:
        """Download the model snapshot from the Hugging Face Hub.

        Args:
            local_dir: Optional target directory forwarded to ``snapshot_download``.
            force: Forwarded as ``force_download``.

        Returns:
            Path of the downloaded snapshot.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        disable_progress_bars()

        # TODO download the correct model (private repo)
        download_path = snapshot_download(
            repo_id="ds4sd/xxx",
            force_download=force,
            local_dir=local_dir,
        )

        return Path(download_path)

    def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
        """Load the page backend and record the page size when the backend is valid.

        Args:
            conv_res: Conversion result whose input backend provides the page.
            page: Page to initialize; it is modified in place.

        Returns:
            The same ``page`` object.
        """
        with TimeRecorder(conv_res, "page_init"):
            page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
            if page._backend is not None and page._backend.is_valid():
                page.size = page._backend.get_size()

        return page

    def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
        """Build ``conv_res.document`` from the page DocTags and attach requested images.

        Args:
            conv_res: Conversion result holding the processed pages.

        Returns:
            The same ``conv_res`` with ``document`` populated.
        """
        with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):

            # Read and concatenate the page doctags; pages without a prediction
            # contribute nothing.
            document_tags = "".join(
                page.predictions.doctags.tag_string
                for page in conv_res.pages
                if page.predictions.doctags is not None
            )

            # TODO implement this function
            conv_res.document = self._turn_tags_into_doc(document_tags)

            # Generate page images in the output
            if self.pipeline_options.generate_page_images:
                # NOTE(review): `_turn_tags_into_doc` is still a stub, so the
                # document presumably has no page entries and this lookup would
                # fail until the tag conversion is implemented - confirm.
                for page in conv_res.pages:
                    assert page.image is not None
                    # Document pages are addressed with an offset of one
                    # relative to `page.page_no`.
                    page_no = page.page_no + 1
                    conv_res.document.pages[page_no].image = ImageRef.from_pil(
                        page.image, dpi=int(72 * self.pipeline_options.images_scale)
                    )

            # Generate images of the requested element types
            if (
                self.pipeline_options.generate_picture_images
                or self.pipeline_options.generate_table_images
            ):
                scale = self.pipeline_options.images_scale
                for element, _level in conv_res.document.iterate_items():
                    # Only items carrying provenance can be located on a page.
                    if not isinstance(element, DocItem) or len(element.prov) == 0:
                        continue
                    if (
                        isinstance(element, PictureItem)
                        and self.pipeline_options.generate_picture_images
                    ) or (
                        isinstance(element, TableItem)
                        and self.pipeline_options.generate_table_images
                    ):
                        page_ix = element.prov[0].page_no - 1
                        page = conv_res.pages[page_ix]
                        assert page.size is not None
                        assert page.image is not None

                        # Scale the box to image resolution and switch to a
                        # top-left origin before cropping.
                        crop_bbox = (
                            element.prov[0]
                            .bbox.scaled(scale=scale)
                            .to_top_left_origin(page_height=page.size.height * scale)
                        )

                        cropped_im = page.image.crop(crop_bbox.as_tuple())
                        element.image = ImageRef.from_pil(
                            cropped_im, dpi=int(72 * scale)
                        )

        return conv_res

    @classmethod
    def get_default_options(cls) -> PdfPipelineOptions:
        """Return a default-constructed ``PdfPipelineOptions``."""
        return PdfPipelineOptions()

    @classmethod
    def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
        """Return ``True`` when ``backend`` is a ``PdfDocumentBackend``."""
        return isinstance(backend, PdfDocumentBackend)

    def _turn_tags_into_doc(self, document_tags: str) -> DoclingDocument:
        """Convert the concatenated DocTags into a ``DoclingDocument``.

        The conversion is not implemented yet: ``document_tags`` is ignored
        and an empty document is returned.

        Args:
            document_tags: Concatenated tag strings of all pages.

        Returns:
            An empty ``DoclingDocument`` with a placeholder name.
        """
        # `name` is a required field of DoclingDocument (per docling-core), so
        # the bare constructor fails validation; use a placeholder for now.
        # TODO: derive the name from the input document.
        return DoclingDocument(name="Document")
|
13
docs/examples/minimal_smol_docling.py
Normal file
13
docs/examples/minimal_smol_docling.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from docling.datamodel.base_models import InputFormat
|
||||||
|
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||||
|
from docling.pipeline.vlm_pipeline import VlmPipeline
|
||||||
|
|
||||||
|
source = "https://arxiv.org/pdf/2408.09869"  # document per local path or URL

# Route PDF inputs through the VLM pipeline. The format options go straight
# into a single DocumentConverter; it does not take a `doc_converter` keyword.
converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_cls=VlmPipeline)}
)
result = converter.convert(source)
print(result.document.export_to_markdown())
# output: ## Docling Technical Report [...]
|
Loading…
Reference in New Issue
Block a user