mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-31 14:34:40 +00:00
Generalize and refactor VLM pipeline and models
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
1c75b52f85
commit
1cba96ecfd
@ -154,8 +154,8 @@ class LayoutPrediction(BaseModel):
|
||||
clusters: List[Cluster] = []
|
||||
|
||||
|
||||
class DocTagsPrediction(BaseModel):
|
||||
tag_string: str = ""
|
||||
class VlmPrediction(BaseModel):
|
||||
text: str = ""
|
||||
|
||||
|
||||
class ContainerElement(
|
||||
@ -201,7 +201,7 @@ class PagePredictions(BaseModel):
|
||||
tablestructure: Optional[TableStructurePrediction] = None
|
||||
figures_classification: Optional[FigureClassificationPrediction] = None
|
||||
equations_prediction: Optional[EquationPrediction] = None
|
||||
doctags: Optional[DocTagsPrediction] = None
|
||||
vlm_response: Optional[VlmPrediction] = None
|
||||
|
||||
|
||||
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
||||
|
@ -254,12 +254,44 @@ granite_picture_description = PictureDescriptionVlmOptions(
|
||||
)
|
||||
|
||||
|
||||
class SmolDoclingOptions(BaseModel):
|
||||
question: str = "Convert this page to docling."
|
||||
class BaseVlmOptions(BaseModel):
|
||||
kind: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class ResponseFormat(str, Enum):
|
||||
DOCTAGS = "doctags"
|
||||
MARKDOWN = "markdown"
|
||||
|
||||
|
||||
class HuggingFaceVlmOptions(BaseVlmOptions):
|
||||
kind: Literal["hf_model_options"] = "hf_model_options"
|
||||
|
||||
repo_id: str
|
||||
load_in_8bit: bool = True
|
||||
llm_int8_threshold: float = 6.0
|
||||
quantized: bool = False
|
||||
|
||||
response_format: ResponseFormat
|
||||
|
||||
@property
|
||||
def repo_cache_folder(self) -> str:
|
||||
return self.repo_id.replace("/", "--")
|
||||
|
||||
|
||||
smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
|
||||
repo_id="ds4sd/SmolDocling-256M-preview",
|
||||
prompt="Convert this page to docling.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
)
|
||||
|
||||
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
|
||||
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
|
||||
# prompt="OCR the full page to markdown.",
|
||||
prompt="OCR this image.",
|
||||
response_format=ResponseFormat.MARKDOWN,
|
||||
)
|
||||
|
||||
|
||||
# Define an enum for the backend options
|
||||
class PdfBackend(str, Enum):
|
||||
@ -300,13 +332,11 @@ class PaginatedPipelineOptions(PipelineOptions):
|
||||
|
||||
class VlmPipelineOptions(PaginatedPipelineOptions):
|
||||
artifacts_path: Optional[Union[Path, str]] = None
|
||||
do_vlm: bool = True # True: perform inference of Visual Language Model
|
||||
|
||||
force_backend_text: bool = (
|
||||
False # (To be used with vlms, or other generative models)
|
||||
)
|
||||
# If True, text from backend will be used instead of generated text
|
||||
vlm_options: Union[SmolDoclingOptions,] = Field(SmolDoclingOptions())
|
||||
vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
|
||||
|
||||
|
||||
class PdfPipelineOptions(PaginatedPipelineOptions):
|
||||
@ -337,8 +367,6 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
|
||||
Field(discriminator="kind"),
|
||||
] = smolvlm_picture_description
|
||||
|
||||
vlm_options: Union[SmolDoclingOptions,] = Field(SmolDoclingOptions())
|
||||
|
||||
images_scale: float = 1.0
|
||||
generate_page_images: bool = False
|
||||
generate_picture_images: bool = False
|
||||
|
@ -3,12 +3,14 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from docling.datamodel.base_models import DocTagsPrediction, Page
|
||||
from transformers import AutoModelForVision2Seq
|
||||
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
SmolDoclingOptions,
|
||||
HuggingFaceVlmOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
@ -18,19 +20,19 @@ from docling.utils.profiling import TimeRecorder
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmolDoclingModel(BasePageModel):
|
||||
|
||||
_repo_id: str = "ds4sd/SmolDocling-256M-preview"
|
||||
class HuggingFaceVlmModel(BasePageModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Path],
|
||||
accelerator_options: AcceleratorOptions,
|
||||
vlm_options: SmolDoclingOptions,
|
||||
vlm_options: HuggingFaceVlmOptions,
|
||||
):
|
||||
self.enabled = enabled
|
||||
|
||||
self.vlm_options = vlm_options
|
||||
|
||||
if self.enabled:
|
||||
import torch
|
||||
from transformers import ( # type: ignore
|
||||
@ -42,17 +44,17 @@ class SmolDoclingModel(BasePageModel):
|
||||
device = decide_device(accelerator_options.device)
|
||||
self.device = device
|
||||
|
||||
_log.debug("Available device for SmolDocling: {}".format(device))
|
||||
_log.debug("Available device for HuggingFace VLM: {}".format(device))
|
||||
|
||||
repo_cache_folder = self._repo_id.replace("/", "--")
|
||||
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
|
||||
|
||||
# PARAMETERS:
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models()
|
||||
artifacts_path = self.download_models(self.vlm_options.repo_id)
|
||||
elif (artifacts_path / repo_cache_folder).exists():
|
||||
artifacts_path = artifacts_path / repo_cache_folder
|
||||
|
||||
self.param_question = vlm_options.question # "Perform Layout Analysis."
|
||||
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
|
||||
self.param_quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=vlm_options.load_in_8bit, # True,
|
||||
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
|
||||
@ -61,22 +63,27 @@ class SmolDoclingModel(BasePageModel):
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
||||
if not self.param_quantized:
|
||||
self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
|
||||
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
|
||||
artifacts_path,
|
||||
# device_map=device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.vlm_model = self.vlm_model.to(device)
|
||||
# _attn_implementation=(
|
||||
# "flash_attention_2" if self.device.startswith("cuda") else "eager"
|
||||
# ),
|
||||
).to(self.device)
|
||||
|
||||
else:
|
||||
self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
|
||||
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
|
||||
artifacts_path,
|
||||
# device_map=device,
|
||||
torch_dtype="auto",
|
||||
quantization_config=self.param_quantization_config,
|
||||
).to(device)
|
||||
# _attn_implementation=(
|
||||
# "flash_attention_2" if self.device.startswith("cuda") else "eager"
|
||||
# ),
|
||||
).to(self.device)
|
||||
|
||||
@staticmethod
|
||||
def download_models(
|
||||
repo_id: str,
|
||||
local_dir: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
progress: bool = False,
|
||||
@ -87,7 +94,7 @@ class SmolDoclingModel(BasePageModel):
|
||||
if not progress:
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id=SmolDoclingModel._repo_id,
|
||||
repo_id=repo_id,
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
# revision="v0.0.1",
|
||||
@ -155,13 +162,13 @@ class SmolDoclingModel(BasePageModel):
|
||||
num_tokens = len(generated_ids[0])
|
||||
page_tags = generated_texts
|
||||
|
||||
inference_time = time.time() - start_time
|
||||
tokens_per_second = num_tokens / generation_time
|
||||
# inference_time = time.time() - start_time
|
||||
# tokens_per_second = num_tokens / generation_time
|
||||
# print("")
|
||||
# print(f"Page Inference Time: {inference_time:.2f} seconds")
|
||||
# print(f"Total tokens on page: {num_tokens:.2f}")
|
||||
# print(f"Tokens/sec: {tokens_per_second:.2f}")
|
||||
# print("")
|
||||
page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)
|
||||
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
||||
|
||||
yield page
|
@ -2,6 +2,7 @@ import itertools
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
|
||||
# from io import BytesIO
|
||||
from pathlib import Path
|
||||
@ -26,12 +27,17 @@ from docling_core.types.doc import (
|
||||
from docling_core.types.doc.tokens import DocumentToken, TableToken
|
||||
|
||||
from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||
from docling.backend.md_backend import MarkdownDocumentBackend
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions, VlmPipelineOptions
|
||||
from docling.datamodel.base_models import InputFormat, Page
|
||||
from docling.datamodel.document import ConversionResult, InputDocument
|
||||
from docling.datamodel.pipeline_options import (
|
||||
PdfPipelineOptions,
|
||||
ResponseFormat,
|
||||
VlmPipelineOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.smol_docling_model import SmolDoclingModel
|
||||
from docling.models.hf_vlm_model import HuggingFaceVlmModel
|
||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
@ -68,57 +74,14 @@ class VlmPipeline(PaginatedPipeline):
|
||||
# force_backend_text = True - get text from backend using bounding boxes predicted by SmolDoclingss
|
||||
self.force_backend_text = pipeline_options.force_backend_text
|
||||
|
||||
###############################################
|
||||
# Tag definitions and color mappings
|
||||
###############################################
|
||||
|
||||
# Maps the recognized tag to a Docling label.
|
||||
# Code items will be given DocItemLabel.CODE
|
||||
self.tag_to_doclabel = {
|
||||
"title": DocItemLabel.TITLE,
|
||||
"document_index": DocItemLabel.DOCUMENT_INDEX,
|
||||
"otsl": DocItemLabel.TABLE,
|
||||
"section_header_level_1": DocItemLabel.SECTION_HEADER,
|
||||
"checkbox_selected": DocItemLabel.CHECKBOX_SELECTED,
|
||||
"checkbox_unselected": DocItemLabel.CHECKBOX_UNSELECTED,
|
||||
"text": DocItemLabel.TEXT,
|
||||
"page_header": DocItemLabel.PAGE_HEADER,
|
||||
"page_footer": DocItemLabel.PAGE_FOOTER,
|
||||
"formula": DocItemLabel.FORMULA,
|
||||
"caption": DocItemLabel.CAPTION,
|
||||
"picture": DocItemLabel.PICTURE,
|
||||
"list_item": DocItemLabel.LIST_ITEM,
|
||||
"footnote": DocItemLabel.FOOTNOTE,
|
||||
"code": DocItemLabel.CODE,
|
||||
}
|
||||
|
||||
# Maps each tag to an associated bounding box color.
|
||||
self.tag_to_color = {
|
||||
"title": "blue",
|
||||
"document_index": "darkblue",
|
||||
"otsl": "green",
|
||||
"section_header_level_1": "purple",
|
||||
"checkbox_selected": "black",
|
||||
"checkbox_unselected": "gray",
|
||||
"text": "red",
|
||||
"page_header": "orange",
|
||||
"page_footer": "cyan",
|
||||
"formula": "pink",
|
||||
"caption": "magenta",
|
||||
"picture": "yellow",
|
||||
"list_item": "brown",
|
||||
"footnote": "darkred",
|
||||
"code": "lightblue",
|
||||
}
|
||||
|
||||
self.keep_images = (
|
||||
self.pipeline_options.generate_page_images
|
||||
or self.pipeline_options.generate_picture_images
|
||||
)
|
||||
|
||||
self.build_pipe = [
|
||||
SmolDoclingModel(
|
||||
enabled=pipeline_options.do_vlm,
|
||||
HuggingFaceVlmModel(
|
||||
enabled=True,
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
vlm_options=self.pipeline_options.vlm_options,
|
||||
@ -140,7 +103,21 @@ class VlmPipeline(PaginatedPipeline):
|
||||
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
|
||||
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
|
||||
|
||||
conv_res.document = self._turn_tags_into_doc(conv_res.pages)
|
||||
if (
|
||||
self.pipeline_options.vlm_options.response_format
|
||||
== ResponseFormat.DOCTAGS
|
||||
):
|
||||
conv_res.document = self._turn_tags_into_doc(conv_res.pages)
|
||||
elif (
|
||||
self.pipeline_options.vlm_options.response_format
|
||||
== ResponseFormat.MARKDOWN
|
||||
):
|
||||
conv_res.document = self._turn_md_into_doc(conv_res)
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}"
|
||||
)
|
||||
|
||||
# Generate images of the requested element types
|
||||
if self.pipeline_options.generate_picture_images:
|
||||
@ -170,7 +147,67 @@ class VlmPipeline(PaginatedPipeline):
|
||||
|
||||
return conv_res
|
||||
|
||||
def _turn_md_into_doc(self, conv_res):
|
||||
predicted_text = ""
|
||||
for pg_idx, page in enumerate(conv_res.pages):
|
||||
if page.predictions.vlm_response:
|
||||
predicted_text += page.predictions.vlm_response.text + "\n\n"
|
||||
response_bytes = BytesIO(predicted_text.encode("utf8"))
|
||||
out_doc = InputDocument(
|
||||
path_or_stream=response_bytes,
|
||||
filename=conv_res.input.file.name,
|
||||
format=InputFormat.MD,
|
||||
backend=MarkdownDocumentBackend,
|
||||
)
|
||||
backend = MarkdownDocumentBackend(
|
||||
in_doc=out_doc,
|
||||
path_or_stream=response_bytes,
|
||||
)
|
||||
return backend.convert()
|
||||
|
||||
def _turn_tags_into_doc(self, pages: list[Page]) -> DoclingDocument:
|
||||
###############################################
|
||||
# Tag definitions and color mappings
|
||||
###############################################
|
||||
|
||||
# Maps the recognized tag to a Docling label.
|
||||
# Code items will be given DocItemLabel.CODE
|
||||
tag_to_doclabel = {
|
||||
"title": DocItemLabel.TITLE,
|
||||
"document_index": DocItemLabel.DOCUMENT_INDEX,
|
||||
"otsl": DocItemLabel.TABLE,
|
||||
"section_header_level_1": DocItemLabel.SECTION_HEADER,
|
||||
"checkbox_selected": DocItemLabel.CHECKBOX_SELECTED,
|
||||
"checkbox_unselected": DocItemLabel.CHECKBOX_UNSELECTED,
|
||||
"text": DocItemLabel.TEXT,
|
||||
"page_header": DocItemLabel.PAGE_HEADER,
|
||||
"page_footer": DocItemLabel.PAGE_FOOTER,
|
||||
"formula": DocItemLabel.FORMULA,
|
||||
"caption": DocItemLabel.CAPTION,
|
||||
"picture": DocItemLabel.PICTURE,
|
||||
"list_item": DocItemLabel.LIST_ITEM,
|
||||
"footnote": DocItemLabel.FOOTNOTE,
|
||||
"code": DocItemLabel.CODE,
|
||||
}
|
||||
|
||||
# Maps each tag to an associated bounding box color.
|
||||
tag_to_color = {
|
||||
"title": "blue",
|
||||
"document_index": "darkblue",
|
||||
"otsl": "green",
|
||||
"section_header_level_1": "purple",
|
||||
"checkbox_selected": "black",
|
||||
"checkbox_unselected": "gray",
|
||||
"text": "red",
|
||||
"page_header": "orange",
|
||||
"page_footer": "cyan",
|
||||
"formula": "pink",
|
||||
"caption": "magenta",
|
||||
"picture": "yellow",
|
||||
"list_item": "brown",
|
||||
"footnote": "darkred",
|
||||
"code": "lightblue",
|
||||
}
|
||||
|
||||
def extract_bounding_box(text_chunk: str) -> Optional[BoundingBox]:
|
||||
"""Extracts <loc_...> bounding box coords from the chunk, normalized by / 500."""
|
||||
@ -357,8 +394,8 @@ class VlmPipeline(PaginatedPipeline):
|
||||
for pg_idx, page in enumerate(pages):
|
||||
xml_content = ""
|
||||
predicted_text = ""
|
||||
if page.predictions.doctags:
|
||||
predicted_text = page.predictions.doctags.tag_string
|
||||
if page.predictions.vlm_response:
|
||||
predicted_text = page.predictions.vlm_response.text
|
||||
image = page.image
|
||||
page_no = pg_idx + 1
|
||||
bounding_boxes = []
|
||||
@ -396,8 +433,8 @@ class VlmPipeline(PaginatedPipeline):
|
||||
tag_name = match.group("tag")
|
||||
|
||||
bbox = extract_bounding_box(full_chunk)
|
||||
doc_label = self.tag_to_doclabel.get(tag_name, DocItemLabel.PARAGRAPH)
|
||||
color = self.tag_to_color.get(tag_name, "white")
|
||||
doc_label = tag_to_doclabel.get(tag_name, DocItemLabel.PARAGRAPH)
|
||||
color = tag_to_color.get(tag_name, "white")
|
||||
|
||||
# Store bounding box + color
|
||||
if bbox:
|
||||
|
@ -5,7 +5,11 @@ from pathlib import Path
|
||||
import yaml
|
||||
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.pipeline_options import SmolDoclingOptions, VlmPipelineOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
VlmPipelineOptions,
|
||||
granite_vision_vlm_conversion_options,
|
||||
smoldocling_vlm_conversion_options,
|
||||
)
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
from docling.pipeline.vlm_pipeline import VlmPipeline
|
||||
|
||||
@ -19,16 +23,9 @@ pipeline_options = VlmPipelineOptions() # artifacts_path="~/local_model_artifac
|
||||
pipeline_options.generate_page_images = True
|
||||
# If force_backend_text = True, text from backend will be used instead of generated text
|
||||
pipeline_options.force_backend_text = False
|
||||
# pipeline_options.do_vlm = True - use False to disable VLM model (i.e. SmallDocling), extra python imports will not be performed
|
||||
|
||||
vlm_options = SmolDoclingOptions(
|
||||
# question="Convert this page to docling.",
|
||||
# load_in_8bit=True,
|
||||
# llm_int8_threshold=6.0,
|
||||
# quantized=False,
|
||||
)
|
||||
|
||||
pipeline_options.vlm_options = vlm_options
|
||||
# pipeline_options.vlm_options = smoldocling_vlm_conversion_options
|
||||
pipeline_options.vlm_options = granite_vision_vlm_conversion_options
|
||||
|
||||
from docling_core.types.doc import DocItemLabel, ImageRefMode
|
||||
from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
|
||||
@ -67,7 +64,7 @@ for source in sources:
|
||||
for page in res.pages:
|
||||
print("")
|
||||
print("Predicted page in DOCTAGS:")
|
||||
print(page.predictions.doctags.tag_string)
|
||||
print(page.predictions.vlm_response.text)
|
||||
|
||||
res.document.save_as_html(
|
||||
filename=Path("{}/{}.html".format(out_path, res.input.file.stem)),
|
||||
|
Loading…
Reference in New Issue
Block a user