mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00
Skeleton for SmolDocling model and VLM Pipeline
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
parent
1d17e7397a
commit
dc3a388aa2
@ -154,6 +154,10 @@ class LayoutPrediction(BaseModel):
|
||||
clusters: List[Cluster] = []
|
||||
|
||||
|
||||
class DocTagsPrediction(BaseModel):
|
||||
tag_string: str = ""
|
||||
|
||||
|
||||
class ContainerElement(
|
||||
BasePageElement
|
||||
): # Used for Form and Key-Value-Regions, only for typing.
|
||||
@ -197,6 +201,7 @@ class PagePredictions(BaseModel):
|
||||
tablestructure: Optional[TableStructurePrediction] = None
|
||||
figures_classification: Optional[FigureClassificationPrediction] = None
|
||||
equations_prediction: Optional[EquationPrediction] = None
|
||||
doctags: Optional[DocTagsPrediction] = None
|
||||
|
||||
|
||||
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
||||
|
58
docling/models/smol_docling_model.py
Normal file
58
docling/models/smol_docling_model.py
Normal file
@ -0,0 +1,58 @@
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Cell,
|
||||
Cluster,
|
||||
DocTagsPrediction,
|
||||
LayoutPrediction,
|
||||
Page,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmolDoclingModel(BasePageModel):
|
||||
|
||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
# self.your_vlm_predictor(..., device) = None # TODO
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
with TimeRecorder(conv_res, "smolvlm"):
|
||||
assert page.size is not None
|
||||
|
||||
hi_res_image = page.get_image(scale=2.0) # 144dpi
|
||||
|
||||
# Call your self.your_vlm_predictor with the page image as input (hi_res_image)
|
||||
# populate page_tags
|
||||
page_tags = ""
|
||||
|
||||
page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)
|
||||
|
||||
yield page
|
139
docling/pipeline/vlm_pipeline.py
Normal file
139
docling/pipeline/vlm_pipeline.py
Normal file
@ -0,0 +1,139 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from docling_core.types import DoclingDocument
|
||||
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
|
||||
|
||||
from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||
from docling.models.smol_docling_model import SmolDoclingModel
|
||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VlmPipeline(PaginatedPipeline):
|
||||
_smol_vlm_path = "model_artifacts/smol_vlm" # TODO or whatever is needed.
|
||||
|
||||
def __init__(self, pipeline_options: PdfPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.pipeline_options: PdfPipelineOptions
|
||||
|
||||
if pipeline_options.artifacts_path is None:
|
||||
self.artifacts_path = self.download_models_hf()
|
||||
else:
|
||||
self.artifacts_path = Path(pipeline_options.artifacts_path)
|
||||
|
||||
keep_images = (
|
||||
self.pipeline_options.generate_page_images
|
||||
or self.pipeline_options.generate_picture_images
|
||||
or self.pipeline_options.generate_table_images
|
||||
)
|
||||
|
||||
self.build_pipe = [
|
||||
SmolDoclingModel(
|
||||
artifacts_path=self.artifacts_path / VlmPipeline._smol_vlm_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
]
|
||||
|
||||
self.enrichment_pipe = [
|
||||
# Other models working on `NodeItem` elements in the DoclingDocument
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
|
||||
# TODO download the correct model (private repo)
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/xxx",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
|
||||
with TimeRecorder(conv_res, "page_init"):
|
||||
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
|
||||
if page._backend is not None and page._backend.is_valid():
|
||||
page.size = page._backend.get_size()
|
||||
|
||||
return page
|
||||
|
||||
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
|
||||
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
|
||||
|
||||
# Read and concatenate the page doctags:
|
||||
document_tags = ""
|
||||
for page in conv_res.pages:
|
||||
if page.predictions.doctags is not None:
|
||||
document_tags += page.predictions.doctags.tag_string
|
||||
|
||||
# TODO implement this function
|
||||
conv_res.document = self._turn_tags_into_doc(document_tags)
|
||||
|
||||
# Generate page images in the output
|
||||
if self.pipeline_options.generate_page_images:
|
||||
for page in conv_res.pages:
|
||||
assert page.image is not None
|
||||
page_no = page.page_no + 1
|
||||
conv_res.document.pages[page_no].image = ImageRef.from_pil(
|
||||
page.image, dpi=int(72 * self.pipeline_options.images_scale)
|
||||
)
|
||||
|
||||
# Generate images of the requested element types
|
||||
if (
|
||||
self.pipeline_options.generate_picture_images
|
||||
or self.pipeline_options.generate_table_images
|
||||
):
|
||||
scale = self.pipeline_options.images_scale
|
||||
for element, _level in conv_res.document.iterate_items():
|
||||
if not isinstance(element, DocItem) or len(element.prov) == 0:
|
||||
continue
|
||||
if (
|
||||
isinstance(element, PictureItem)
|
||||
and self.pipeline_options.generate_picture_images
|
||||
) or (
|
||||
isinstance(element, TableItem)
|
||||
and self.pipeline_options.generate_table_images
|
||||
):
|
||||
page_ix = element.prov[0].page_no - 1
|
||||
page = conv_res.pages[page_ix]
|
||||
assert page.size is not None
|
||||
assert page.image is not None
|
||||
|
||||
crop_bbox = (
|
||||
element.prov[0]
|
||||
.bbox.scaled(scale=scale)
|
||||
.to_top_left_origin(page_height=page.size.height * scale)
|
||||
)
|
||||
|
||||
cropped_im = page.image.crop(crop_bbox.as_tuple())
|
||||
element.image = ImageRef.from_pil(
|
||||
cropped_im, dpi=int(72 * scale)
|
||||
)
|
||||
|
||||
return conv_res
|
||||
|
||||
@classmethod
|
||||
def get_default_options(cls) -> PdfPipelineOptions:
|
||||
return PdfPipelineOptions()
|
||||
|
||||
@classmethod
|
||||
def is_backend_supported(cls, backend: AbstractDocumentBackend):
|
||||
return isinstance(backend, PdfDocumentBackend)
|
||||
|
||||
def _turn_tags_into_doc(self, document_tags):
|
||||
return DoclingDocument()
|
13
docs/examples/minimal_smol_docling.py
Normal file
13
docs/examples/minimal_smol_docling.py
Normal file
@ -0,0 +1,13 @@
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
from docling.pipeline.vlm_pipeline import VlmPipeline
|
||||
|
||||
source = "https://arxiv.org/pdf/2408.09869" # document per local path or URL
|
||||
converter = DocumentConverter(
|
||||
doc_converter=DocumentConverter(
|
||||
format_options={InputFormat.PDF: PdfFormatOption(pipeline_cls=VlmPipeline)}
|
||||
)
|
||||
)
|
||||
result = converter.convert(source)
|
||||
print(result.document.export_to_markdown())
|
||||
# output: ## Docling Technical Report [...]"
|
Loading…
Reference in New Issue
Block a user