mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 20:58:11 +00:00
feat: [Experimental] Introduce VLM pipeline using HF AutoModelForVision2Seq, featuring SmolDocling model (#1054)
* Skeleton for SmolDocling model and VLM Pipeline
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* wip smolDocling inference and vlm pipeline
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* WIP, first working code for inference of SmolDocling, and vlm pipeline assembly code, example included.
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Fixes to preserve page image and demo export to html
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Enabled figure support in vlm_pipeline
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Fix for table span compute in vlm_pipeline
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Properly propagating image data per page, together with predicted tags in VLM pipeline. This enables correct figure extraction and page numbers in provenances
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Cleaned up logs, added pages to vlm_pipeline, basic timing per page measurement in smol_docling models
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Replaced hardcoded otsl tokens with the ones from docling-core tokens.py enum
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Added tokens/sec measurement, improved example
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Added capability for vlm_pipeline to grab text from preconfigured backend
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Exposed "force_backend_text" as pipeline parameter
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Flipped keep_backend to True for vlm_pipeline assembly to work
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Updated vlm pipeline assembly and smol docling model code to support updated doctags
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Fixing doctags starting tag, that broke elements on first line during assembly
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Introduced SmolDoclingOptions to configure model parameters (such as query and artifacts path) via client code, see example in minimal_smol_docling. Provisioning for other potential vlm all-in-one models.
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Moved artifacts_path for SmolDocling into vlm_options instead of global pipeline option
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* New assembly code for latest model revision, updated prompt and parsing of doctags, updated logging
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Updated example of Smol Docling usage
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Added captions for the images for SmolDocling assembly code, improved provenance definition for all elements
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Update minimal smoldocling example
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Fix repo id
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Cleaned up unnecessary logging
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* More elegant solution in removing the input prompt
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* removed minimal_smol_docling example from CI checks
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Removed special html code wrapping when exporting to docling document, cleaned up comments
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Addressing PR comments, added enabled property to SmolDocling, and related VLM pipeline option, few other minor things
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Moved keep_backend = True to vlm pipeline
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* removed pipeline_options.generate_table_images from vlm_pipeline (deprecated in the pipelines)
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Added example on how to get original predicted doctags in minimal_smol_docling
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* removing changes from base_pipeline
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Replaced remaining strings to appropriate enums
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Updated poetry.lock
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* re-built poetry.lock
  Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
* Generalize and refactor VLM pipeline and models
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Rename example
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Move imports
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Expose control over using flash_attention_2
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Fix VLM example exclusion in CI
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Add back device_map and accelerate
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Make drawing code resilient against bad bboxes
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* chore: clean up code and comments
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* chore: more cleanup
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* chore: fix leftover .to(device)
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* fix: add proper table provenance
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
---------
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
Co-authored-by: Maksym Lysak <mly@zurich.ibm.com>
@@ -154,6 +154,10 @@ class LayoutPrediction(BaseModel):
    clusters: List[Cluster] = []


class VlmPrediction(BaseModel):
    text: str = ""


class ContainerElement(
    BasePageElement
):  # Used for Form and Key-Value-Regions, only for typing.
@@ -197,6 +201,7 @@ class PagePredictions(BaseModel):
    tablestructure: Optional[TableStructurePrediction] = None
    figures_classification: Optional[FigureClassificationPrediction] = None
    equations_prediction: Optional[EquationPrediction] = None
    vlm_response: Optional[VlmPrediction] = None


PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
@@ -41,6 +41,7 @@ class AcceleratorOptions(BaseSettings):
    num_threads: int = 4
    device: Union[str, AcceleratorDevice] = "auto"
    cuda_use_flash_attention2: bool = False

    @field_validator("device")
    def validate_device(cls, value):
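The `cuda_use_flash_attention2` flag added in this hunk is consumed further down in this commit, where `HuggingFaceVlmModel` picks an attention implementation. The selection rule can be sketched in isolation as below; `pick_attn_implementation` is a hypothetical helper name used only for illustration and is not part of docling.

```python
def pick_attn_implementation(device: str, use_flash_attention2: bool) -> str:
    """Mirror the conditional expression used in the diff.

    Flash attention is requested only when the resolved device is a CUDA
    device AND the user opted in; every other combination falls back to
    the "eager" implementation.
    """
    if device.startswith("cuda") and use_flash_attention2:
        return "flash_attention_2"
    return "eager"


if __name__ == "__main__":
    for device, flag in [("cuda:0", True), ("cuda:0", False), ("cpu", True)]:
        print(device, flag, "->", pick_attn_implementation(device, flag))
```

Because the flag defaults to `False`, the eager path is what runs unless a caller explicitly opts in on a CUDA device.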
@@ -254,6 +255,45 @@ granite_picture_description = PictureDescriptionVlmOptions(
)


class BaseVlmOptions(BaseModel):
    kind: str
    prompt: str


class ResponseFormat(str, Enum):
    DOCTAGS = "doctags"
    MARKDOWN = "markdown"


class HuggingFaceVlmOptions(BaseVlmOptions):
    kind: Literal["hf_model_options"] = "hf_model_options"

    repo_id: str
    load_in_8bit: bool = True
    llm_int8_threshold: float = 6.0
    quantized: bool = False

    response_format: ResponseFormat

    @property
    def repo_cache_folder(self) -> str:
        return self.repo_id.replace("/", "--")


smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
)

granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
    # prompt="OCR the full page to markdown.",
    prompt="OCR this image.",
    response_format=ResponseFormat.MARKDOWN,
)


# Define an enum for the backend options
class PdfBackend(str, Enum):
    """Enum of valid PDF backends."""
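The `repo_cache_folder` property above turns a Hugging Face repo id into a folder name, and the model class later in this commit checks for that folder under `artifacts_path`. A minimal standalone sketch of that lookup follows; `resolve_model_dir` is a hypothetical name, and returning `None` stands in for the "download from the hub" branch of the real code.

```python
import tempfile
from pathlib import Path
from typing import Optional


def repo_cache_folder(repo_id: str) -> str:
    """Same transformation as the property in the diff: '/' becomes '--'."""
    return repo_id.replace("/", "--")


def resolve_model_dir(artifacts_path: Optional[Path], repo_id: str) -> Optional[Path]:
    """Decide which local directory the model would be loaded from.

    None means no artifacts path was configured, which in the diff triggers
    a download instead. (Hypothetical helper, for illustration only.)
    """
    if artifacts_path is None:
        return None
    candidate = artifacts_path / repo_cache_folder(repo_id)
    # Prefer the per-repo subfolder when present, else use the path as given.
    return candidate if candidate.exists() else artifacts_path


if __name__ == "__main__":
    repo = "ds4sd/SmolDocling-256M-preview"
    print(repo_cache_folder(repo))
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        print(resolve_model_dir(root, repo) == root)  # no subfolder yet
        (root / repo_cache_folder(repo)).mkdir()
        print(resolve_model_dir(root, repo).name)  # now the subfolder wins
```

One consequence of this rule is that a single artifacts directory can hold several models side by side, each in its own `org--name` folder.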
@@ -285,7 +325,24 @@ class PipelineOptions(BaseModel):
    enable_remote_services: bool = False


class PdfPipelineOptions(PipelineOptions):
class PaginatedPipelineOptions(PipelineOptions):
    images_scale: float = 1.0
    generate_page_images: bool = False
    generate_picture_images: bool = False


class VlmPipelineOptions(PaginatedPipelineOptions):
    artifacts_path: Optional[Union[Path, str]] = None

    generate_page_images: bool = True
    force_backend_text: bool = (
        False  # (To be used with vlms, or other generative models)
    )
    # If True, text from backend will be used instead of generated text
    vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options


class PdfPipelineOptions(PaginatedPipelineOptions):
    """Options for the PDF pipeline."""

    artifacts_path: Optional[Union[Path, str]] = None
@@ -295,6 +352,10 @@ class PdfPipelineOptions(PipelineOptions):
    do_formula_enrichment: bool = False  # True: perform formula OCR, return Latex code
    do_picture_classification: bool = False  # True: classify pictures in documents
    do_picture_description: bool = False  # True: run describe pictures in documents
    force_backend_text: bool = (
        False  # (To be used with vlms, or other generative models)
    )
    # If True, text from backend will be used instead of generated text

    table_structure_options: TableStructureOptions = TableStructureOptions()
    ocr_options: Union[
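`force_backend_text` does not take effect on its own: the `VlmPipeline` constructor in this commit only honours it when the response format is DocTags, since that is the format carrying bounding boxes. The sketch below restates that gating with a local stand-in enum; `effective_force_backend_text` is a hypothetical function name.

```python
from enum import Enum


class ResponseFormat(str, Enum):
    """Local stand-in for the enum defined in the diff."""

    DOCTAGS = "doctags"
    MARKDOWN = "markdown"


def effective_force_backend_text(
    force_backend_text: bool, response_format: ResponseFormat
) -> bool:
    """Backend text is used only if requested AND the format is DocTags."""
    return force_backend_text and response_format == ResponseFormat.DOCTAGS


if __name__ == "__main__":
    print(effective_force_backend_text(True, ResponseFormat.DOCTAGS))
    print(effective_force_backend_text(True, ResponseFormat.MARKDOWN))
```

With a Markdown-producing model the option is therefore silently ignored rather than rejected.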
180
docling/models/hf_vlm_model.py
Normal file
@@ -0,0 +1,180 @@
import logging
import time
from pathlib import Path
from typing import Iterable, List, Optional

from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
    HuggingFaceVlmOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)


class HuggingFaceVlmModel(BasePageModel):

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: HuggingFaceVlmOptions,
    ):
        self.enabled = enabled

        self.vlm_options = vlm_options

        if self.enabled:
            import torch
            from transformers import (  # type: ignore
                AutoModelForVision2Seq,
                AutoProcessor,
                BitsAndBytesConfig,
            )

            device = decide_device(accelerator_options.device)
            self.device = device

            _log.debug("Available device for HuggingFace VLM: {}".format(device))

            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

            # PARAMETERS:
            if artifacts_path is None:
                artifacts_path = self.download_models(self.vlm_options.repo_id)
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
            self.param_quantization_config = BitsAndBytesConfig(
                load_in_8bit=vlm_options.load_in_8bit,  # True,
                llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
            )
            self.param_quantized = vlm_options.quantized  # False

            self.processor = AutoProcessor.from_pretrained(artifacts_path)
            if not self.param_quantized:
                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                    artifacts_path,
                    device_map=device,
                    torch_dtype=torch.bfloat16,
                    _attn_implementation=(
                        "flash_attention_2"
                        if self.device.startswith("cuda")
                        and accelerator_options.cuda_use_flash_attention2
                        else "eager"
                    ),
                )  # .to(self.device)

            else:
                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                    artifacts_path,
                    device_map=device,
                    torch_dtype="auto",
                    quantization_config=self.param_quantization_config,
                    _attn_implementation=(
                        "flash_attention_2"
                        if self.device.startswith("cuda")
                        and accelerator_options.cuda_use_flash_attention2
                        else "eager"
                    ),
                )  # .to(self.device)

    @staticmethod
    def download_models(
        repo_id: str,
        local_dir: Optional[Path] = None,
        force: bool = False,
        progress: bool = False,
    ) -> Path:
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id=repo_id,
            force_download=force,
            local_dir=local_dir,
            # revision="v0.0.1",
        )

        return Path(download_path)

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "vlm"):
                    assert page.size is not None

                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi

                    if hi_res_image is not None:
                        im_width, im_height = hi_res_image.size

                    # populate page_tags with predicted doc tags
                    page_tags = ""

                    if hi_res_image:
                        if hi_res_image.mode != "RGB":
                            hi_res_image = hi_res_image.convert("RGB")

                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": "This is a page from a document.",
                                },
                                {"type": "image"},
                                {"type": "text", "text": self.param_question},
                            ],
                        }
                    ]
                    prompt = self.processor.apply_chat_template(
                        messages, add_generation_prompt=False
                    )
                    inputs = self.processor(
                        text=prompt, images=[hi_res_image], return_tensors="pt"
                    )
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}

                    start_time = time.time()
                    # Call model to generate:
                    generated_ids = self.vlm_model.generate(
                        **inputs, max_new_tokens=4096, use_cache=True
                    )

                    generation_time = time.time() - start_time
                    generated_texts = self.processor.batch_decode(
                        generated_ids[:, inputs["input_ids"].shape[1] :],
                        skip_special_tokens=False,
                    )[0]

                    num_tokens = len(generated_ids[0])
                    page_tags = generated_texts

                    # inference_time = time.time() - start_time
                    # tokens_per_second = num_tokens / generation_time
                    # print("")
                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
                    # print(f"Total tokens on page: {num_tokens:.2f}")
                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
                    # print("")
                    page.predictions.vlm_response = VlmPrediction(text=page_tags)

                yield page
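In `__call__`, the decoded text is taken from `generated_ids[:, inputs["input_ids"].shape[1] :]`, i.e. the prompt tokens are sliced off before decoding. The sketch below shows the same slicing on plain Python lists, plus a throughput figure that counts only new tokens. The helper names are hypothetical, and the token ids are made up; note that `num_tokens = len(generated_ids[0])` in the diff appears to count the prompt tokens as well.

```python
from typing import List


def strip_prompt(generated_ids: List[List[int]], prompt_len: int) -> List[List[int]]:
    """Drop the first `prompt_len` ids of every sequence in the batch.

    List equivalent of the tensor slice `generated_ids[:, prompt_len:]`.
    """
    return [seq[prompt_len:] for seq in generated_ids]


def tokens_per_second(num_new_tokens: int, seconds: float) -> float:
    """Throughput over newly generated tokens; 0.0 if no time elapsed."""
    if seconds <= 0:
        return 0.0
    return num_new_tokens / seconds


if __name__ == "__main__":
    prompt_ids = [101, 7, 8]  # made-up ids standing in for the encoded prompt
    batch = [prompt_ids + [42, 43, 44, 45]]  # generate() output echoes the prompt
    new_ids = strip_prompt(batch, len(prompt_ids))
    print(new_ids)
    print(tokens_per_second(len(new_ids[0]), 2.0))
```

Counting only the sliced ids keeps the tokens/sec number independent of prompt length.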
534
docling/pipeline/vlm_pipeline.py
Normal file
@@ -0,0 +1,534 @@
import itertools
import logging
import re
import warnings
from io import BytesIO

# from io import BytesIO
from pathlib import Path
from typing import Optional

from docling_core.types import DoclingDocument
from docling_core.types.doc import (
    BoundingBox,
    DocItem,
    DocItemLabel,
    DoclingDocument,
    GroupLabel,
    ImageRef,
    ImageRefMode,
    PictureItem,
    ProvenanceItem,
    Size,
    TableCell,
    TableData,
    TableItem,
)
from docling_core.types.doc.tokens import DocumentToken, TableToken

from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    ResponseFormat,
    VlmPipelineOptions,
)
from docling.datamodel.settings import settings
from docling.models.hf_vlm_model import HuggingFaceVlmModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder

_log = logging.getLogger(__name__)


class VlmPipeline(PaginatedPipeline):

    def __init__(self, pipeline_options: VlmPipelineOptions):
        super().__init__(pipeline_options)
        self.keep_backend = True

        warnings.warn(
            "The VlmPipeline is currently experimental and may change in upcoming versions without notice.",
            category=UserWarning,
            stacklevel=2,
        )

        self.pipeline_options: VlmPipelineOptions

        artifacts_path: Optional[Path] = None
        if pipeline_options.artifacts_path is not None:
            artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
        elif settings.artifacts_path is not None:
            artifacts_path = Path(settings.artifacts_path).expanduser()

        if artifacts_path is not None and not artifacts_path.is_dir():
            raise RuntimeError(
                f"The value of {artifacts_path=} is not valid. "
                "When defined, it must point to a folder containing all models required by the pipeline."
            )

        # force_backend_text = False - use text that is coming from VLM response
        # force_backend_text = True - get text from backend using bounding boxes predicted by SmolDocling doctags
        self.force_backend_text = (
            pipeline_options.force_backend_text
            and pipeline_options.vlm_options.response_format == ResponseFormat.DOCTAGS
        )

        self.keep_images = self.pipeline_options.generate_page_images

        self.build_pipe = [
            HuggingFaceVlmModel(
                enabled=True,  # must be always enabled for this pipeline to make sense.
                artifacts_path=artifacts_path,
                accelerator_options=pipeline_options.accelerator_options,
                vlm_options=self.pipeline_options.vlm_options,
            ),
        ]

        self.enrichment_pipe = [
            # Other models working on `NodeItem` elements in the DoclingDocument
        ]

    def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
        with TimeRecorder(conv_res, "page_init"):
            page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
            if page._backend is not None and page._backend.is_valid():
                page.size = page._backend.get_size()

        return page

    def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
        with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):

            if (
                self.pipeline_options.vlm_options.response_format
                == ResponseFormat.DOCTAGS
            ):
                conv_res.document = self._turn_tags_into_doc(conv_res.pages)
            elif (
                self.pipeline_options.vlm_options.response_format
                == ResponseFormat.MARKDOWN
            ):
                conv_res.document = self._turn_md_into_doc(conv_res)

            else:
                raise RuntimeError(
                    f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}"
                )

            # Generate images of the requested element types
            if self.pipeline_options.generate_picture_images:
                scale = self.pipeline_options.images_scale
                for element, _level in conv_res.document.iterate_items():
                    if not isinstance(element, DocItem) or len(element.prov) == 0:
                        continue
                    if (
                        isinstance(element, PictureItem)
                        and self.pipeline_options.generate_picture_images
                    ):
                        page_ix = element.prov[0].page_no - 1
                        page = conv_res.pages[page_ix]
                        assert page.size is not None
                        assert page.image is not None

                        crop_bbox = (
                            element.prov[0]
                            .bbox.scaled(scale=scale)
                            .to_top_left_origin(page_height=page.size.height * scale)
                        )

                        cropped_im = page.image.crop(crop_bbox.as_tuple())
                        element.image = ImageRef.from_pil(
                            cropped_im, dpi=int(72 * scale)
                        )

        return conv_res

    def _turn_md_into_doc(self, conv_res):
        predicted_text = ""
        for pg_idx, page in enumerate(conv_res.pages):
            if page.predictions.vlm_response:
                predicted_text += page.predictions.vlm_response.text + "\n\n"
        response_bytes = BytesIO(predicted_text.encode("utf8"))
        out_doc = InputDocument(
            path_or_stream=response_bytes,
            filename=conv_res.input.file.name,
            format=InputFormat.MD,
            backend=MarkdownDocumentBackend,
        )
        backend = MarkdownDocumentBackend(
            in_doc=out_doc,
            path_or_stream=response_bytes,
        )
        return backend.convert()

    def _turn_tags_into_doc(self, pages: list[Page]) -> DoclingDocument:
        ###############################################
        # Tag definitions and color mappings
        ###############################################

        # Maps the recognized tag to a Docling label.
        # Code items will be given DocItemLabel.CODE
        tag_to_doclabel = {
            "title": DocItemLabel.TITLE,
            "document_index": DocItemLabel.DOCUMENT_INDEX,
            "otsl": DocItemLabel.TABLE,
            "section_header_level_1": DocItemLabel.SECTION_HEADER,
            "checkbox_selected": DocItemLabel.CHECKBOX_SELECTED,
            "checkbox_unselected": DocItemLabel.CHECKBOX_UNSELECTED,
            "text": DocItemLabel.TEXT,
            "page_header": DocItemLabel.PAGE_HEADER,
            "page_footer": DocItemLabel.PAGE_FOOTER,
            "formula": DocItemLabel.FORMULA,
            "caption": DocItemLabel.CAPTION,
            "picture": DocItemLabel.PICTURE,
            "list_item": DocItemLabel.LIST_ITEM,
            "footnote": DocItemLabel.FOOTNOTE,
            "code": DocItemLabel.CODE,
        }

        # Maps each tag to an associated bounding box color.
        tag_to_color = {
            "title": "blue",
            "document_index": "darkblue",
            "otsl": "green",
            "section_header_level_1": "purple",
            "checkbox_selected": "black",
            "checkbox_unselected": "gray",
            "text": "red",
            "page_header": "orange",
            "page_footer": "cyan",
            "formula": "pink",
            "caption": "magenta",
            "picture": "yellow",
            "list_item": "brown",
            "footnote": "darkred",
            "code": "lightblue",
        }

        def extract_bounding_box(text_chunk: str) -> Optional[BoundingBox]:
            """Extracts <loc_...> bounding box coords from the chunk, normalized by / 500."""
            coords = re.findall(r"<loc_(\d+)>", text_chunk)
            if len(coords) == 4:
                l, t, r, b = map(float, coords)
                return BoundingBox(l=l / 500, t=t / 500, r=r / 500, b=b / 500)
            return None

        def extract_inner_text(text_chunk: str) -> str:
            """Strips all <...> tags inside the chunk to get the raw text content."""
            return re.sub(r"<.*?>", "", text_chunk, flags=re.DOTALL).strip()

        def extract_text_from_backend(page: Page, bbox: BoundingBox | None) -> str:
            # Convert bounding box normalized to 0-1 into page coordinates for cropping
            text = ""
            if bbox:
                if page.size:
                    bbox.l = bbox.l * page.size.width
                    bbox.t = bbox.t * page.size.height
                    bbox.r = bbox.r * page.size.width
                    bbox.b = bbox.b * page.size.height
                    if page._backend:
                        text = page._backend.get_text_in_rect(bbox)
            return text

        def otsl_parse_texts(texts, tokens):
            split_word = TableToken.OTSL_NL.value
            split_row_tokens = [
                list(y)
                for x, y in itertools.groupby(tokens, lambda z: z == split_word)
                if not x
            ]
            table_cells = []
            r_idx = 0
            c_idx = 0

            def count_right(tokens, c_idx, r_idx, which_tokens):
                span = 0
                c_idx_iter = c_idx
                while tokens[r_idx][c_idx_iter] in which_tokens:
                    c_idx_iter += 1
                    span += 1
                    if c_idx_iter >= len(tokens[r_idx]):
                        return span
                return span

            def count_down(tokens, c_idx, r_idx, which_tokens):
                span = 0
                r_idx_iter = r_idx
                while tokens[r_idx_iter][c_idx] in which_tokens:
                    r_idx_iter += 1
                    span += 1
                    if r_idx_iter >= len(tokens):
                        return span
                return span

            for i, text in enumerate(texts):
                cell_text = ""
                if text in [
                    TableToken.OTSL_FCEL.value,
                    TableToken.OTSL_ECEL.value,
                    TableToken.OTSL_CHED.value,
                    TableToken.OTSL_RHED.value,
                    TableToken.OTSL_SROW.value,
                ]:
                    row_span = 1
                    col_span = 1
                    right_offset = 1
                    if text != TableToken.OTSL_ECEL.value:
                        cell_text = texts[i + 1]
                        right_offset = 2

                    # Check next element(s) for lcel / ucel / xcel, set properly row_span, col_span
                    next_right_cell = ""
                    if i + right_offset < len(texts):
                        next_right_cell = texts[i + right_offset]

                    next_bottom_cell = ""
                    if r_idx + 1 < len(split_row_tokens):
                        if c_idx < len(split_row_tokens[r_idx + 1]):
                            next_bottom_cell = split_row_tokens[r_idx + 1][c_idx]

                    if next_right_cell in [
                        TableToken.OTSL_LCEL.value,
                        TableToken.OTSL_XCEL.value,
                    ]:
                        # we have a horizontal spanning cell or 2d spanning cell
                        col_span += count_right(
                            split_row_tokens,
                            c_idx + 1,
                            r_idx,
                            [TableToken.OTSL_LCEL.value, TableToken.OTSL_XCEL.value],
                        )
                    if next_bottom_cell in [
                        TableToken.OTSL_UCEL.value,
                        TableToken.OTSL_XCEL.value,
                    ]:
                        # we have a vertical spanning cell or 2d spanning cell
                        row_span += count_down(
                            split_row_tokens,
                            c_idx,
                            r_idx + 1,
                            [TableToken.OTSL_UCEL.value, TableToken.OTSL_XCEL.value],
                        )

                    table_cells.append(
                        TableCell(
                            text=cell_text.strip(),
                            row_span=row_span,
                            col_span=col_span,
                            start_row_offset_idx=r_idx,
                            end_row_offset_idx=r_idx + row_span,
                            start_col_offset_idx=c_idx,
                            end_col_offset_idx=c_idx + col_span,
                        )
                    )
                if text in [
                    TableToken.OTSL_FCEL.value,
                    TableToken.OTSL_ECEL.value,
                    TableToken.OTSL_CHED.value,
                    TableToken.OTSL_RHED.value,
                    TableToken.OTSL_SROW.value,
                    TableToken.OTSL_LCEL.value,
                    TableToken.OTSL_UCEL.value,
                    TableToken.OTSL_XCEL.value,
                ]:
                    c_idx += 1
                if text == TableToken.OTSL_NL.value:
                    r_idx += 1
                    c_idx = 0
            return table_cells, split_row_tokens

        def otsl_extract_tokens_and_text(s: str):
            # Pattern to match anything enclosed by < > (including the angle brackets themselves)
            pattern = r"(<[^>]+>)"
            # Find all tokens (e.g. "<otsl>", "<loc_140>", etc.)
            tokens = re.findall(pattern, s)
            # Remove any tokens that start with "<loc_"
            tokens = [
                token
                for token in tokens
                if not (
                    token.startswith(rf"<{DocumentToken.LOC.value}")
                    or token
                    in [
                        rf"<{DocumentToken.OTSL.value}>",
                        rf"</{DocumentToken.OTSL.value}>",
                    ]
                )
            ]
            # Split the string by those tokens to get the in-between text
            text_parts = re.split(pattern, s)
            text_parts = [
                token
                for token in text_parts
                if not (
                    token.startswith(rf"<{DocumentToken.LOC.value}")
                    or token
                    in [
                        rf"<{DocumentToken.OTSL.value}>",
                        rf"</{DocumentToken.OTSL.value}>",
                    ]
                )
            ]
            # Remove any empty or purely whitespace strings from text_parts
            text_parts = [part for part in text_parts if part.strip()]

            return tokens, text_parts

        def parse_table_content(otsl_content: str) -> TableData:
            tokens, mixed_texts = otsl_extract_tokens_and_text(otsl_content)
            table_cells, split_row_tokens = otsl_parse_texts(mixed_texts, tokens)

            return TableData(
                num_rows=len(split_row_tokens),
                num_cols=(
                    max(len(row) for row in split_row_tokens) if split_row_tokens else 0
                ),
                table_cells=table_cells,
            )

        doc = DoclingDocument(name="Document")
        for pg_idx, page in enumerate(pages):
            xml_content = ""
            predicted_text = ""
            if page.predictions.vlm_response:
                predicted_text = page.predictions.vlm_response.text
            image = page.image

            page_no = pg_idx + 1
            bounding_boxes = []

            if page.size:
                pg_width = page.size.width
                pg_height = page.size.height
                size = Size(width=pg_width, height=pg_height)
                parent_page = doc.add_page(page_no=page_no, size=size)

            """
            1. Finds all <tag>...</tag> blocks in the entire string (multi-line friendly) in the order they appear.
            2. For each chunk, extracts bounding box (if any) and inner text.
            3. Adds the item to a DoclingDocument structure with the right label.
            4. Tracks bounding boxes + color in a separate list for later visualization.
            """

            # Regex for all recognized tags
            tag_pattern = (
                rf"<(?P<tag>{DocItemLabel.TITLE}|{DocItemLabel.DOCUMENT_INDEX}|"
                rf"{DocItemLabel.CHECKBOX_UNSELECTED}|{DocItemLabel.CHECKBOX_SELECTED}|"
                rf"{DocItemLabel.TEXT}|{DocItemLabel.PAGE_HEADER}|"
                rf"{DocItemLabel.PAGE_FOOTER}|{DocItemLabel.FORMULA}|"
                rf"{DocItemLabel.CAPTION}|{DocItemLabel.PICTURE}|"
                rf"{DocItemLabel.LIST_ITEM}|{DocItemLabel.FOOTNOTE}|{DocItemLabel.CODE}|"
                rf"{DocItemLabel.SECTION_HEADER}_level_1|{DocumentToken.OTSL.value})>.*?</(?P=tag)>"
            )

            # DocumentToken.OTSL
            pattern = re.compile(tag_pattern, re.DOTALL)

            # Go through each match in order
            for match in pattern.finditer(predicted_text):
                full_chunk = match.group(0)
                tag_name = match.group("tag")

                bbox = extract_bounding_box(full_chunk)
                doc_label = tag_to_doclabel.get(tag_name, DocItemLabel.PARAGRAPH)
                color = tag_to_color.get(tag_name, "white")

                # Store bounding box + color
                if bbox:
                    bounding_boxes.append((bbox, color))

                if tag_name == DocumentToken.OTSL.value:
                    table_data = parse_table_content(full_chunk)
                    bbox = extract_bounding_box(full_chunk)

                    if bbox:
                        prov = ProvenanceItem(
                            bbox=bbox.resize_by_scale(pg_width, pg_height),
                            charspan=(0, 0),
                            page_no=page_no,
                        )
                        doc.add_table(data=table_data, prov=prov)
                    else:
                        doc.add_table(data=table_data)

                elif tag_name == DocItemLabel.PICTURE:
                    text_caption_content = extract_inner_text(full_chunk)
                    if image:
                        if bbox:
                            im_width, im_height = image.size

                            crop_box = (
                                int(bbox.l * im_width),
                                int(bbox.t * im_height),
                                int(bbox.r * im_width),
                                int(bbox.b * im_height),
                            )
                            cropped_image = image.crop(crop_box)
                            pic = doc.add_picture(
                                parent=None,
                                image=ImageRef.from_pil(image=cropped_image, dpi=72),
                                prov=(
                                    ProvenanceItem(
                                        bbox=bbox.resize_by_scale(pg_width, pg_height),
                                        charspan=(0, 0),
                                        page_no=page_no,
                                    )
                                ),
                            )
                            # If there is a caption to an image, add it as well
                            if len(text_caption_content) > 0:
                                caption_item = doc.add_text(
                                    label=DocItemLabel.CAPTION,
                                    text=text_caption_content,
                                    parent=None,
                                )
                                pic.captions.append(caption_item.get_ref())
                    else:
                        if bbox:
                            # In case we don't have access to the binary of an image
                            pic = doc.add_picture(
                                parent=None,
                                prov=ProvenanceItem(
                                    bbox=bbox, charspan=(0, 0), page_no=page_no
                                ),
                            )
                            # If there is a caption to an image, add it as well
                            if len(text_caption_content) > 0:
                                caption_item = doc.add_text(
                                    label=DocItemLabel.CAPTION,
                                    text=text_caption_content,
                                    parent=None,
                                )
                                pic.captions.append(caption_item.get_ref())
                else:
                    # For everything else, treat as text
                    if self.force_backend_text:
                        text_content = extract_text_from_backend(page, bbox)
                    else:
                        text_content = extract_inner_text(full_chunk)
                    doc.add_text(
                        label=doc_label,
                        text=text_content,
                        prov=(
                            ProvenanceItem(
                                bbox=bbox.resize_by_scale(pg_width, pg_height),
                                charspan=(0, len(text_content)),
                                page_no=page_no,
                            )
                            if bbox
                            else None
                        ),
                    )
        return doc
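The core of `_turn_tags_into_doc` is a regex with a named group and a backreference (`(?P<tag>...)` ... `</(?P=tag)>`), followed by `<loc_N>` extraction scaled by 1/500. The sketch below reproduces that idea with the standard library only. The tag list is shortened, plain tuples replace `BoundingBox`, and the sample string is a made-up illustration rather than real SmolDocling output.

```python
import re
from typing import List, Optional, Tuple

Box = Tuple[float, float, float, float]

# Shortened tag list for illustration; the diff covers many more labels.
TAGS = ["title", "text", "section_header_level_1", "picture", "otsl"]
TAG_PATTERN = re.compile(
    rf"<(?P<tag>{'|'.join(TAGS)})>.*?</(?P=tag)>", re.DOTALL
)


def extract_bbox(chunk: str) -> Optional[Box]:
    """Return (l, t, r, b) normalized to 0-1, or None without exactly 4 locs."""
    coords = re.findall(r"<loc_(\d+)>", chunk)
    if len(coords) != 4:
        return None
    l, t, r, b = (int(c) / 500 for c in coords)
    return (l, t, r, b)


def inner_text(chunk: str) -> str:
    """Strip every <...> token and surrounding whitespace."""
    return re.sub(r"<.*?>", "", chunk, flags=re.DOTALL).strip()


def parse_doctags(s: str) -> List[Tuple[str, Optional[Box], str]]:
    """Yield (tag, bbox, text) for each matched block, in document order."""
    return [
        (m.group("tag"), extract_bbox(m.group(0)), inner_text(m.group(0)))
        for m in TAG_PATTERN.finditer(s)
    ]


if __name__ == "__main__":
    sample = (
        "<title><loc_50><loc_25><loc_450><loc_50>Annual Report</title>\n"
        "<text>No location here</text>"
    )
    for item in parse_doctags(sample):
        print(item)
```

The backreference is what keeps a `<text>` block from being closed by a `</title>`; the non-greedy `.*?` with `re.DOTALL` lets a block span several lines.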

    @classmethod
    def get_default_options(cls) -> VlmPipelineOptions:
        return VlmPipelineOptions()

    @classmethod
    def is_backend_supported(cls, backend: AbstractDocumentBackend):
        return isinstance(backend, PdfDocumentBackend)
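Table parsing in this file splits the OTSL token stream into rows with `itertools.groupby` and then counts continuation tokens to derive spans. A self-contained sketch of those two steps follows. The literal token strings (`"<nl>"`, `"<lcel>"`, and so on) are assumptions standing in for the `TableToken` enum values from docling-core, and `count_right` here uses a bounds-checked loop rather than the exact control flow of the diff.

```python
import itertools
from typing import Collection, List

OTSL_NL = "<nl>"  # assumed value of TableToken.OTSL_NL


def split_rows(tokens: List[str], newline_token: str = OTSL_NL) -> List[List[str]]:
    """Group consecutive non-newline tokens into rows.

    groupby alternates between runs of newline tokens and runs of cell
    tokens; only the cell runs are kept.
    """
    return [
        list(group)
        for is_newline, group in itertools.groupby(
            tokens, lambda tok: tok == newline_token
        )
        if not is_newline
    ]


def count_right(
    rows: List[List[str]], r_idx: int, c_idx: int, span_tokens: Collection[str]
) -> int:
    """Count consecutive span-continuation tokens starting at (r_idx, c_idx)."""
    span = 0
    while c_idx < len(rows[r_idx]) and rows[r_idx][c_idx] in span_tokens:
        c_idx += 1
        span += 1
    return span


if __name__ == "__main__":
    tokens = ["<ched>", "<lcel>", "<nl>", "<fcel>", "<fcel>", "<nl>"]
    rows = split_rows(tokens)
    print(rows)
    # The header cell at (0, 0) spans itself plus every <lcel>/<xcel> to its right.
    print(1 + count_right(rows, 0, 1, {"<lcel>", "<xcel>"}))
```

One property of the `groupby` approach worth knowing: back-to-back newline tokens collapse into a single run, so an entirely empty row does not appear in the result.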
@@ -43,6 +43,11 @@ def draw_clusters(
            y0 *= scale_x
            y1 *= scale_y

            if y1 <= y0:
                y1, y0 = y0, y1
            if x1 <= x0:
                x1, x0 = x0, x1

            cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70)
            cluster_outline_color = (
                *list(DocItemLabel.get_color(c.label)),
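The four added lines in this hunk reorder the corners of a box whose coordinates arrive inverted, which matches the "resilient against bad bboxes" item in the commit message; presumably the downstream drawing call can reject rectangles whose second corner lies before the first. The same normalization as a standalone function (`normalize_box` is a hypothetical name):

```python
from typing import Tuple


def normalize_box(
    x0: float, y0: float, x1: float, y1: float
) -> Tuple[float, float, float, float]:
    """Swap coordinates so that x0 <= x1 and y0 <= y1, as in the diff."""
    if y1 <= y0:
        y1, y0 = y0, y1
    if x1 <= x0:
        x1, x0 = x0, x1
    return x0, y0, x1, y1


if __name__ == "__main__":
    print(normalize_box(10, 40, 5, 20))  # both axes inverted
    print(normalize_box(0, 0, 5, 5))  # already well-formed, unchanged
```

The swap leaves degenerate boxes (zero width or height) as they are, so callers that need a non-empty area still have to check for that separately.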