mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
feat(experimental): Layout + VLM model with layout prompt (#2244)
* adding granite-docling preview
* updated the model specs
* Add Layout+VLM pipeline with prompt injection, ApiVlmModel updates
* Update layout injection, move to experimental
* Adjust defaults
* Map Layout+VLM pipeline to GraniteDocling
* Remove base_prompt from layout injection prompt
* Reinstate custom prompt
* add demo_layout file that produces output with vs. without layout injection
* feat: wrap vlm_inference around process_images
* feat: carry input prompt + number of input tokens
* fix: adapt example to run on local test file
* fix: example now expects single document
* feat: add layout example to EXAMPLES_TO_SKIP
* feat: address review comments
* feat: add inference wrapper for hf_transformers + carry input prompt
* feat: add track_input_prompt to ApiVlmOptions, and track input prompt as part of api vlm
* fix: ensure backward-compatible build_prompt by adding _internal_page arg
* Fixes for demo
* Typing fixes
* Restoring lost changes in vllm_model
* Restoring vlm_pipeline_api_model example

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Co-authored-by: Peter Staar <taa@zurich.ibm.com>
Co-authored-by: ElHachem02 <peterelhachem02@gmail.com>
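
For context: the new pipeline is exercised end-to-end by docs/examples/demo_layout_vlm.py (added in this diff). A condensed sketch of that usage, with the input path and batch sizes taken from the demo and everything else left at the defaults (GRANITEDOCLING_TRANSFORMERS for the VLM, DOCLING_LAYOUT_HERON for layout):

    from pathlib import Path

    from docling.datamodel.base_models import InputFormat
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.experimental.datamodel.threaded_layout_vlm_pipeline_options import (
        ThreadedLayoutVlmPipelineOptions,
    )
    from docling.experimental.pipeline.threaded_layout_vlm_pipeline import (
        ThreadedLayoutVlmPipeline,
    )

    # Layout detections are injected into the VLM prompt by the pipeline itself;
    # only the experimental pipeline class and its options need to be wired in.
    options = ThreadedLayoutVlmPipelineOptions(layout_batch_size=2, vlm_batch_size=1)

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=ThreadedLayoutVlmPipeline,
                pipeline_options=options,
            )
        }
    )

    result = converter.convert(source=Path("tests/data/pdf/code_and_formula.pdf"))
    print(result.document.export_to_markdown())
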
.github/workflows/checks.yml (vendored): 14 changes

@@ -20,7 +20,7 @@ env:
     tests/test_asr_pipeline.py
     tests/test_threaded_pipeline.py
   PYTEST_TO_SKIP: |-
-  EXAMPLES_TO_SKIP: '^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model|granitedocling_repetition_stopping|mlx_whisper_example|gpu_standard_pipeline|gpu_vlm_pipeline)\.py$'
+  EXAMPLES_TO_SKIP: '^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model|granitedocling_repetition_stopping|mlx_whisper_example|gpu_standard_pipeline|gpu_vlm_pipeline|demo_layout_vlm)\.py$'
 
 jobs:
   lint:
@@ -28,7 +28,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ['3.12']
+        python-version: ["3.12"]
     steps:
       - uses: actions/checkout@v5
 
@@ -62,7 +62,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
     steps:
       - uses: actions/checkout@v5
 
@@ -129,7 +129,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-       python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
+       python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
    steps:
      - uses: actions/checkout@v5
 
@@ -201,7 +201,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-       python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
+       python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
    steps:
      - uses: actions/checkout@v5
 
@@ -291,7 +291,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-       python-version: ['3.12']
+       python-version: ["3.12"]
    steps:
      - uses: actions/checkout@v5
 
@@ -322,7 +322,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-       python-version: ['3.12']
+       python-version: ["3.12"]
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
docling/datamodel/base_models.py:

@@ -216,6 +216,7 @@ class VlmPrediction(BaseModel):
     generation_time: float = -1
     num_tokens: Optional[int] = None
     stop_reason: VlmStopReason = VlmStopReason.UNSPECIFIED
+    input_prompt: Optional[str] = None
 
 
 class ContainerElement(
docling/datamodel/pipeline_options_vlm_model.py:

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 
 from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel, ConfigDict
@@ -9,6 +9,11 @@ from typing_extensions import deprecated
 from docling.datamodel.accelerator_options import AcceleratorDevice
 from docling.models.utils.generation_utils import GenerationStopper
 
+if TYPE_CHECKING:
+    from docling_core.types.doc.page import SegmentedPage
+
+    from docling.datamodel.base_models import Page
+
 
 class BaseVlmOptions(BaseModel):
     kind: str
@@ -17,7 +22,22 @@ class BaseVlmOptions(BaseModel):
     max_size: Optional[int] = None
     temperature: float = 0.0
 
-    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+    def build_prompt(
+        self,
+        page: Optional["SegmentedPage"],
+        *,
+        _internal_page: Optional["Page"] = None,
+    ) -> str:
+        """Build the prompt for VLM inference.
+
+        Args:
+            page: The parsed/segmented page to process.
+            _internal_page: Internal parameter for experimental layout-aware pipelines.
+                Do not rely on this in user code - subject to change.
+
+        Returns:
+            The formatted prompt string.
+        """
         return self.prompt
 
     def decode_response(self, text: str) -> str:
@@ -83,6 +103,7 @@ class InlineVlmOptions(BaseVlmOptions):
     use_kv_cache: bool = True
     max_new_tokens: int = 4096
     track_generated_tokens: bool = False
+    track_input_prompt: bool = False
 
     @property
     def repo_cache_folder(self) -> str:
@@ -110,3 +131,4 @@ class ApiVlmOptions(BaseVlmOptions):
 
     stop_strings: List[str] = []
     custom_stopping_criteria: List[Union[GenerationStopper]] = []
+    track_input_prompt: bool = False
docling/experimental/__init__.py (new file): 5 lines

@@ -0,0 +1,5 @@
+"""Experimental modules for Docling.
+
+This package contains experimental features that are under development
+and may change or be removed in future versions.
+"""
docling/experimental/datamodel/__init__.py (new file): 1 line

@@ -0,0 +1 @@
+"""Experimental datamodel modules."""
docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py (new file): 45 lines

@@ -0,0 +1,45 @@
+"""Options for the threaded layout+VLM pipeline."""
+
+from typing import Union
+
+from pydantic import model_validator
+
+from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_HERON
+from docling.datamodel.pipeline_options import LayoutOptions, PaginatedPipelineOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    ApiVlmOptions,
+    InlineVlmOptions,
+    ResponseFormat,
+)
+from docling.datamodel.vlm_model_specs import GRANITEDOCLING_TRANSFORMERS
+
+
+class ThreadedLayoutVlmPipelineOptions(PaginatedPipelineOptions):
+    """Pipeline options for the threaded layout+VLM pipeline."""
+
+    images_scale: float = 2.0
+
+    # VLM configuration (will be enhanced with layout awareness by the pipeline)
+    vlm_options: Union[InlineVlmOptions, ApiVlmOptions] = GRANITEDOCLING_TRANSFORMERS
+
+    # Layout model configuration
+    layout_options: LayoutOptions = LayoutOptions(
+        model_spec=DOCLING_LAYOUT_HERON, skip_cell_assignment=True
+    )
+
+    # Threading and batching controls
+    layout_batch_size: int = 4
+    vlm_batch_size: int = 4
+    batch_timeout_seconds: float = 2.0
+    queue_max_size: int = 50
+
+    @model_validator(mode="after")
+    def validate_response_format(self):
+        """Validate that VLM response format is DOCTAGS (required for this pipeline)."""
+        if self.vlm_options.response_format != ResponseFormat.DOCTAGS:
+            raise ValueError(
+                f"ThreadedLayoutVlmPipeline only supports DOCTAGS response format, "
+                f"but got {self.vlm_options.response_format}. "
+                f"Please set vlm_options.response_format=ResponseFormat.DOCTAGS"
+            )
+        return self
docling/experimental/pipeline/__init__.py (new file): 1 line

@@ -0,0 +1 @@
+"""Experimental pipeline modules."""
docling/experimental/pipeline/threaded_layout_vlm_pipeline.py (new file): 436 lines

@@ -0,0 +1,436 @@
+"""Threaded Layout+VLM Pipeline
+================================
+A specialized two-stage threaded pipeline that combines layout model preprocessing
+with VLM processing. The layout model detects document elements and coordinates,
+which are then injected into the VLM prompt for enhanced structured output.
+"""
+
+from __future__ import annotations
+
+import itertools
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Union, cast
+
+from docling_core.types.doc import DoclingDocument
+from docling_core.types.doc.document import DocTagsDocument
+from PIL import Image as PILImage
+
+if TYPE_CHECKING:
+    from docling_core.types.doc.page import SegmentedPage
+
+from docling.backend.abstract_backend import AbstractDocumentBackend
+from docling.backend.pdf_backend import PdfDocumentBackend
+from docling.datamodel.base_models import ConversionStatus, Page
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import (
+    ApiVlmOptions,
+    InferenceFramework,
+    InlineVlmOptions,
+)
+from docling.datamodel.settings import settings
+from docling.experimental.datamodel.threaded_layout_vlm_pipeline_options import (
+    ThreadedLayoutVlmPipelineOptions,
+)
+from docling.models.api_vlm_model import ApiVlmModel
+from docling.models.base_model import BaseVlmPageModel
+from docling.models.layout_model import LayoutModel
+from docling.models.vlm_models_inline.hf_transformers_model import (
+    HuggingFaceTransformersVlmModel,
+)
+from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
+from docling.pipeline.base_pipeline import BasePipeline
+from docling.pipeline.standard_pdf_pipeline import (
+    ProcessingResult,
+    RunContext,
+    ThreadedItem,
+    ThreadedPipelineStage,
+    ThreadedQueue,
+)
+from docling.utils.profiling import ProfilingScope, TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class ThreadedLayoutVlmPipeline(BasePipeline):
+    """Two-stage threaded pipeline: Layout Model → VLM Model."""
+
+    def __init__(self, pipeline_options: ThreadedLayoutVlmPipelineOptions) -> None:
+        super().__init__(pipeline_options)
+        self.pipeline_options: ThreadedLayoutVlmPipelineOptions = pipeline_options
+        self._run_seq = itertools.count(1)  # deterministic, monotonic run ids
+
+        # VLM model type (initialized in _init_models)
+        self.vlm_model: BaseVlmPageModel
+
+        # Initialize models
+        self._init_models()
+
+    def _init_models(self) -> None:
+        """Initialize layout and VLM models."""
+        art_path = self._resolve_artifacts_path()
+
+        # Layout model
+        self.layout_model = LayoutModel(
+            artifacts_path=art_path,
+            accelerator_options=self.pipeline_options.accelerator_options,
+            options=self.pipeline_options.layout_options,
+        )
+
+        # VLM model based on options type
+        # Create layout-aware VLM options internally
+        base_vlm_options = self.pipeline_options.vlm_options
+
+        class LayoutAwareVlmOptions(type(base_vlm_options)):  # type: ignore[misc]
+            def build_prompt(
+                self,
+                page: Optional[SegmentedPage],
+                *,
+                _internal_page: Optional[Page] = None,
+            ) -> str:
+                base_prompt = self.prompt
+                augmented_prompt = base_prompt
+
+                # In this layout-aware pipeline, _internal_page is always provided
+                if _internal_page is None:
+                    return base_prompt
+
+                if not _internal_page.size:
+                    _log.warning(
+                        f"Page size not available for page {_internal_page.page_no}. Cannot enhance prompt with layout info."
+                    )
+                    return base_prompt
+
+                if _internal_page.predictions.layout:
+                    from docling_core.types.doc.tokens import DocumentToken
+
+                    layout_elements = []
+                    for cluster in _internal_page.predictions.layout.clusters:
+                        # Get proper tag name from DocItemLabel
+                        tag_name = DocumentToken.create_token_name_from_doc_item_label(
+                            label=cluster.label
+                        )
+
+                        # Convert bbox to tuple and get location tokens
+                        bbox_tuple = cluster.bbox.as_tuple()
+                        location_tokens = DocumentToken.get_location(
+                            bbox=bbox_tuple,
+                            page_w=_internal_page.size.width,
+                            page_h=_internal_page.size.height,
+                        )
+
+                        # Create XML element with DocTags format
+                        xml_element = f"<{tag_name}>{location_tokens}</{tag_name}>"
+                        layout_elements.append(xml_element)
+
+                    if layout_elements:
+                        # Join elements with newlines and wrap in layout tags
+                        layout_xml = (
+                            "<layout>" + "\n".join(layout_elements) + "</layout>"
+                        )
+                        layout_injection = f"{layout_xml}"
+
+                        augmented_prompt = base_prompt + layout_injection
+
+                    _log.debug(
+                        "Enhanced Prompt with Layout Info: %s\n", augmented_prompt
+                    )
+
+                return augmented_prompt
+
+        vlm_options = LayoutAwareVlmOptions(**base_vlm_options.model_dump())
+
+        if isinstance(base_vlm_options, ApiVlmOptions):
+            self.vlm_model = ApiVlmModel(
+                enabled=True,
+                enable_remote_services=self.pipeline_options.enable_remote_services,
+                vlm_options=vlm_options,
+            )
+        elif isinstance(base_vlm_options, InlineVlmOptions):
+            if vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
+                self.vlm_model = HuggingFaceTransformersVlmModel(
+                    enabled=True,
+                    artifacts_path=art_path,
+                    accelerator_options=self.pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                )
+            elif vlm_options.inference_framework == InferenceFramework.MLX:
+                self.vlm_model = HuggingFaceMlxModel(
+                    enabled=True,
+                    artifacts_path=art_path,
+                    accelerator_options=self.pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                )
+            elif vlm_options.inference_framework == InferenceFramework.VLLM:
+                from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
+
+                self.vlm_model = VllmVlmModel(
+                    enabled=True,
+                    artifacts_path=art_path,
+                    accelerator_options=self.pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported VLM inference framework: {vlm_options.inference_framework}"
+                )
+        else:
+            raise ValueError(f"Unsupported VLM options type: {type(base_vlm_options)}")
+
+    def _resolve_artifacts_path(self) -> Optional[Path]:
+        """Resolve artifacts path from options or settings."""
+        if self.pipeline_options.artifacts_path:
+            p = Path(self.pipeline_options.artifacts_path).expanduser()
+        elif settings.artifacts_path:
+            p = Path(settings.artifacts_path).expanduser()
+        else:
+            return None
+        if not p.is_dir():
+            raise RuntimeError(
+                f"{p} does not exist or is not a directory containing the required models"
+            )
+        return p
+
+    def _create_run_ctx(self) -> RunContext:
+        """Create pipeline stages and wire them together."""
+        opts = self.pipeline_options
+
+        # Layout stage
+        layout_stage = ThreadedPipelineStage(
+            name="layout",
+            model=self.layout_model,
+            batch_size=opts.layout_batch_size,
+            batch_timeout=opts.batch_timeout_seconds,
+            queue_max_size=opts.queue_max_size,
+        )
+
+        # VLM stage - now layout-aware through enhanced build_prompt
+        vlm_stage = ThreadedPipelineStage(
+            name="vlm",
+            model=self.vlm_model,
+            batch_size=opts.vlm_batch_size,
+            batch_timeout=opts.batch_timeout_seconds,
+            queue_max_size=opts.queue_max_size,
+        )
+
+        # Wire stages
+        output_q = ThreadedQueue(opts.queue_max_size)
+        layout_stage.add_output_queue(vlm_stage.input_queue)
+        vlm_stage.add_output_queue(output_q)
+
+        stages = [layout_stage, vlm_stage]
+        return RunContext(
+            stages=stages, first_stage=layout_stage, output_queue=output_q
+        )
+
+    def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
+        """Build document using threaded layout+VLM pipeline."""
+        run_id = next(self._run_seq)
+        assert isinstance(conv_res.input._backend, PdfDocumentBackend)
+        backend = conv_res.input._backend
+
+        # Initialize pages
+        start_page, end_page = conv_res.input.limits.page_range
+        pages: List[Page] = []
+        for i in range(conv_res.input.page_count):
+            if start_page - 1 <= i <= end_page - 1:
+                page = Page(page_no=i)
+                page._backend = backend.load_page(i)
+                if page._backend and page._backend.is_valid():
+                    page.size = page._backend.get_size()
+                    conv_res.pages.append(page)
+                    pages.append(page)
+
+        if not pages:
+            conv_res.status = ConversionStatus.FAILURE
+            return conv_res
+
+        total_pages = len(pages)
+        ctx = self._create_run_ctx()
+        for st in ctx.stages:
+            st.start()
+
+        proc = ProcessingResult(total_expected=total_pages)
+        fed_idx = 0
+        batch_size = 32
+
+        try:
+            while proc.success_count + proc.failure_count < total_pages:
+                # Feed pages to first stage
+                while fed_idx < total_pages:
+                    ok = ctx.first_stage.input_queue.put(
+                        ThreadedItem(
+                            payload=pages[fed_idx],
+                            run_id=run_id,
+                            page_no=pages[fed_idx].page_no,
+                            conv_res=conv_res,
+                        ),
+                        timeout=0.0,
+                    )
+                    if ok:
+                        fed_idx += 1
+                        if fed_idx == total_pages:
+                            ctx.first_stage.input_queue.close()
+                    else:
+                        break
+
+                # Drain results from output
+                out_batch = ctx.output_queue.get_batch(batch_size, timeout=0.05)
+                for itm in out_batch:
+                    if itm.run_id != run_id:
+                        continue
+                    if itm.is_failed or itm.error:
+                        proc.failed_pages.append(
+                            (itm.page_no, itm.error or RuntimeError("unknown error"))
+                        )
+                    else:
+                        assert itm.payload is not None
+                        proc.pages.append(itm.payload)
+
+                # Handle early termination
+                if not out_batch and ctx.output_queue.closed:
+                    missing = total_pages - (proc.success_count + proc.failure_count)
+                    if missing > 0:
+                        proc.failed_pages.extend(
+                            [(-1, RuntimeError("pipeline terminated early"))] * missing
+                        )
+                    break
+        finally:
+            for st in ctx.stages:
+                st.stop()
+            ctx.output_queue.close()
+
+        self._integrate_results(conv_res, proc)
+        return conv_res
+
+    def _integrate_results(
+        self, conv_res: ConversionResult, proc: ProcessingResult
+    ) -> None:
+        """Integrate processing results into conversion result."""
+        page_map = {p.page_no: p for p in proc.pages}
+
+        # Track failed pages for cleanup
+        failed_page_nos = {fp for fp, _ in proc.failed_pages}
+
+        # Collect pages that will be removed (failed pages) for resource cleanup
+        pages_to_remove = [p for p in conv_res.pages if p.page_no in failed_page_nos]
+
+        conv_res.pages = [
+            page_map.get(p.page_no, p)
+            for p in conv_res.pages
+            if p.page_no in page_map
+            or not any(fp == p.page_no for fp, _ in proc.failed_pages)
+        ]
+
+        if proc.is_complete_failure:
+            conv_res.status = ConversionStatus.FAILURE
+        elif proc.is_partial_success:
+            conv_res.status = ConversionStatus.PARTIAL_SUCCESS
+        else:
+            conv_res.status = ConversionStatus.SUCCESS
+
+        # Clean up resources for failed pages that were removed
+        for p in pages_to_remove:
+            if p._backend is not None:
+                p._backend.unload()
+            p._image_cache = {}
+            # Clean up parsed_page if it exists (it's Optional[SegmentedPdfPage])
+            if p.parsed_page is not None:
+                del p.parsed_page
+                p.parsed_page = None
+
+        # Clean up images if not needed for remaining pages
+        if not self.pipeline_options.generate_page_images:
+            for p in conv_res.pages:
+                p._image_cache = {}
+
+    def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
+        """Assemble final document from VLM predictions."""
+        from docling_core.types.doc import DocItem, ImageRef, PictureItem
+
+        from docling.datamodel.pipeline_options_vlm_model import ResponseFormat
+
+        with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
+            # Response format validation is done in ThreadedLayoutVlmPipelineOptions
+            # This check is kept as a safety net, but should never trigger if validation works
+            if (
+                self.pipeline_options.vlm_options.response_format
+                != ResponseFormat.DOCTAGS
+            ):
+                raise RuntimeError(
+                    f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}. Only DOCTAGS format is supported."
+                )
+            conv_res.document = self._turn_dt_into_doc(conv_res)
+
+            # Generate images of the requested element types
+            if self.pipeline_options.generate_picture_images:
+                # Create mapping from page_no to Page object since pages may be non-continuous
+                page_map = {p.page_no: p for p in conv_res.pages}
+                scale = self.pipeline_options.images_scale
+                for element, _level in conv_res.document.iterate_items():
+                    if not isinstance(element, DocItem) or len(element.prov) == 0:
+                        continue
+                    if (
+                        isinstance(element, PictureItem)
+                        and self.pipeline_options.generate_picture_images
+                    ):
+                        page_no = element.prov[0].page_no
+                        page = page_map.get(page_no)
+                        if page is None:
+                            _log.warning(
+                                f"Page {page_no} not found in conversion result for picture element. Skipping image generation."
+                            )
+                            continue
+                        assert page.size is not None
+                        assert page.image is not None
+
+                        crop_bbox = (
+                            element.prov[0]
+                            .bbox.scaled(scale=scale)
+                            .to_top_left_origin(page_height=page.size.height * scale)
+                        )
+
+                        cropped_im = page.image.crop(crop_bbox.as_tuple())
+                        element.image = ImageRef.from_pil(
+                            cropped_im, dpi=int(72 * scale)
+                        )
+
+        return conv_res
+
+    def _turn_dt_into_doc(self, conv_res: ConversionResult) -> DoclingDocument:
+        """Convert DOCTAGS response format to DoclingDocument."""
+        doctags_list = []
+        image_list = []
+        for page in conv_res.pages:
+            # Only include pages that have both an image and VLM predictions
+            if page.image and page.predictions.vlm_response:
+                predicted_doctags = page.predictions.vlm_response.text
+                image_list.append(page.image)
+                doctags_list.append(predicted_doctags)
+
+        doctags_list_c = cast(List[Union[Path, str]], doctags_list)
+        image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
+        doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
+            doctags_list_c, image_list_c
+        )
+        document = DoclingDocument.load_from_doctags(doctag_document=doctags_doc)
+
+        return document
+
+    @classmethod
+    def get_default_options(cls) -> ThreadedLayoutVlmPipelineOptions:
+        return ThreadedLayoutVlmPipelineOptions()
+
+    @classmethod
+    def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
+        return isinstance(backend, PdfDocumentBackend)
+
+    def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
+        return conv_res.status
+
+    def _unload(self, conv_res: ConversionResult) -> None:
+        for p in conv_res.pages:
+            if p._backend is not None:
+                p._backend.unload()
+        if conv_res.input._backend:
+            conv_res.input._backend.unload()
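
For illustration, the LayoutAwareVlmOptions.build_prompt defined above appends a DocTags-style <layout> block to the base prompt. A hypothetical sketch of the resulting string for a page with one picture and one text cluster (the tag names and <loc_..> values below are made up; the real ones come from DocumentToken and the detected bounding boxes):

    base_prompt = "Convert this page to docling."
    # Hypothetical injection for two detected clusters:
    layout_xml = (
        "<layout>"
        "<picture><loc_34><loc_28><loc_466><loc_221></picture>\n"
        "<text><loc_34><loc_240><loc_466><loc_402></text>"
        "</layout>"
    )
    augmented_prompt = base_prompt + layout_xml
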
docling/models/api_vlm_model.py:

@@ -1,13 +1,15 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
+from typing import Union
 
-from transformers import StoppingCriteria
+import numpy as np
+from PIL.Image import Image
 
 from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.generation_utils import GenerationStopper
 from docling.utils.api_image_request import (
     api_image_request,
@@ -16,7 +18,10 @@ from docling.utils.api_image_request import (
 from docling.utils.profiling import TimeRecorder
 
 
-class ApiVlmModel(BasePageModel):
+class ApiVlmModel(BaseVlmPageModel):
+    # Override the vlm_options type annotation from BaseVlmPageModel
+    vlm_options: ApiVlmOptions  # type: ignore[assignment]
+
     def __init__(
         self,
         enabled: bool,
@@ -43,22 +48,90 @@ class ApiVlmModel(BasePageModel):
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        def _vlm_request(page):
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        original_order = page_list[:]
+        valid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
-            if not page._backend.is_valid():
-                return page
+            if page._backend.is_valid():
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
             with TimeRecorder(conv_res, "vlm"):
-                assert page.size is not None
+                # Prepare images and prompts for batch processing
+                images = []
+                prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
                     hi_res_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
-                assert hi_res_image is not None
-                if hi_res_image and hi_res_image.mode != "RGB":
-                    hi_res_image = hi_res_image.convert("RGB")
 
-                prompt = self.vlm_options.build_prompt(page.parsed_page)
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+                        prompt = self._build_prompt_safe(page)
+                        prompts.append(prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        predictions = list(self.process_images(images, prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield pages preserving original order
+        for page in original_order:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+        images = list(image_batch)
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            prompts = [prompt] * len(images)
+        elif isinstance(prompt, list):
+            if len(prompt) != len(images):
+                raise ValueError(
+                    f"Prompt list length ({len(prompt)}) must match image count ({len(images)})"
+                )
+            prompts = prompt
+
+        def _process_single_image(image_prompt_pair):
+            image, prompt_text = image_prompt_pair
+
+            # Convert numpy array to PIL Image if needed
+            if isinstance(image, np.ndarray):
+                if image.ndim == 3 and image.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    image = PILImage.fromarray(image.astype(np.uint8))
+                elif image.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {image.shape}")
+
+            # Ensure image is in RGB mode
+            if image.mode != "RGB":
+                image = image.convert("RGB")
 
             stop_reason = VlmStopReason.UNSPECIFIED
 
             if self.vlm_options.custom_stopping_criteria:
@@ -74,23 +147,20 @@ class ApiVlmModel(BasePageModel):
                 # Skip non-GenerationStopper criteria (should have been caught in validation)
 
                 # Streaming path with early abort support
-                with TimeRecorder(conv_res, "vlm_inference"):
                 page_tags, num_tokens = api_image_request_streaming(
-                    image=hi_res_image,
-                    prompt=prompt,
+                    image=image,
+                    prompt=prompt_text,
                     url=self.vlm_options.url,
                     timeout=self.timeout,
                     headers=self.vlm_options.headers,
                     generation_stoppers=instantiated_stoppers,
                     **self.params,
                 )
-                page_tags = self.vlm_options.decode_response(page_tags)
             else:
                 # Non-streaming fallback (existing behavior)
-                with TimeRecorder(conv_res, "vlm_inference"):
                 page_tags, num_tokens, stop_reason = api_image_request(
-                    image=hi_res_image,
-                    prompt=prompt,
+                    image=image,
+                    prompt=prompt_text,
                     url=self.vlm_options.url,
                     timeout=self.timeout,
                     headers=self.vlm_options.headers,
@@ -98,11 +168,13 @@ class ApiVlmModel(BasePageModel):
                 )
 
             page_tags = self.vlm_options.decode_response(page_tags)
-            page.predictions.vlm_response = VlmPrediction(
-                text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason
+            input_prompt = prompt_text if self.vlm_options.track_input_prompt else None
+            return VlmPrediction(
+                text=page_tags,
+                num_tokens=num_tokens,
+                stop_reason=stop_reason,
+                input_prompt=input_prompt,
             )
-            return page
 
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
-            yield from executor.map(_vlm_request, page_batch)
+            yield from executor.map(_process_single_image, zip(images, prompts))
docling/models/base_model.py:

@@ -76,6 +76,24 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel):
     vlm_options: InlineVlmOptions
     processor: Any
 
+    def _build_prompt_safe(self, page: Page) -> str:
+        """Build prompt with backward compatibility for user overrides.
+
+        Tries to call build_prompt with _internal_page parameter (for layout-aware
+        pipelines). Falls back to basic call if user override doesn't accept it.
+
+        Args:
+            page: The full Page object with layout predictions and parsed_page.
+
+        Returns:
+            The formatted prompt string.
+        """
+        try:
+            return self.vlm_options.build_prompt(page.parsed_page, _internal_page=page)
+        except TypeError:
+            # User override doesn't accept _internal_page - fall back to basic call
+            return self.vlm_options.build_prompt(page.parsed_page)
+
     @abstractmethod
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
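
To make the backward-compatibility path above concrete: an existing user subclass whose build_prompt override predates the _internal_page keyword keeps working, because _build_prompt_safe catches the TypeError and retries with the plain call. A minimal sketch (the subclass name and prompt text are hypothetical):

    from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions

    class LegacyPromptOptions(ApiVlmOptions):
        # Old-style override: only the segmented page, no _internal_page keyword.
        def build_prompt(self, page):
            return "Convert this page to docling."
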
docling/models/vlm_models_inline/hf_transformers_model.py:

@@ -176,13 +176,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 images.append(hi_res_image)
 
                 # Define prompt structure
-                user_prompt = self.vlm_options.build_prompt(page.parsed_page)
+                user_prompt = self._build_prompt_safe(page)
 
                 user_prompts.append(user_prompt)
                 pages_with_images.append(page)
 
         # Use process_images for the actual inference
         if images:  # Only if we have valid images
+            with TimeRecorder(conv_res, "vlm_inference"):
                 predictions = list(self.process_images(images, user_prompts))
 
             # Attach results to pages
@@ -375,7 +376,10 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 f"for batch size {generated_ids.shape[0]}."
             )
 
-        for text in decoded_texts:
+        for i, text in enumerate(decoded_texts):
+            input_prompt = (
+                prompts[i] if self.vlm_options.track_input_prompt and prompts else None
+            )
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
             yield VlmPrediction(
@@ -383,4 +387,5 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 generation_time=generation_time,
                 num_tokens=num_tokens,
                 stop_reason=VlmStopReason.UNSPECIFIED,
+                input_prompt=input_prompt,
             )
docling/models/vlm_models_inline/mlx_model.py:

@@ -134,10 +134,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 images.append(hi_res_image)
 
                 # Define prompt structure
-                if callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    user_prompt = self.vlm_options.prompt
+                user_prompt = self._build_prompt_safe(page)
 
                 user_prompts.append(user_prompt)
                 pages_with_images.append(page)
@@ -319,11 +316,15 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
 
                 # Apply decode_response to the output before yielding
                 decoded_output = self.vlm_options.decode_response(output)
+                input_prompt = (
+                    formatted_prompt if self.vlm_options.track_input_prompt else None
+                )
                 yield VlmPrediction(
                     text=decoded_output,
                     generation_time=generation_time,
                     generated_tokens=tokens,
                     num_tokens=len(tokens),
                     stop_reason=VlmStopReason.UNSPECIFIED,
+                    input_prompt=input_prompt,
                 )
             _log.debug("MLX model: Released global lock")
docling/models/vlm_models_inline/vllm_model.py:

@@ -233,7 +233,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 images.append(hi_res_image)
 
                 # Define prompt structure
-                user_prompt = self.vlm_options.build_prompt(page.parsed_page)
+                user_prompt = self._build_prompt_safe(page)
 
                 user_prompts.append(user_prompt)
                 pages_with_images.append(page)
@@ -314,19 +314,25 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         num_tokens_within_batch = 0
 
         # Emit predictions
-        for output in outputs:
+        for i, output in enumerate(outputs):
             text = output.outputs[0].text if output.outputs else ""
             stop_reason = (
                 VlmStopReason.END_OF_SEQUENCE
                 if output.outputs[0].stop_reason
                 else VlmStopReason.LENGTH
             )
-            generated_tokens = (
-                [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
-                if self.vlm_options.track_generated_tokens
-                else []
-            )
+            generated_tokens = [
+                VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids
+            ]
             num_tokens = len(generated_tokens)
+
+            if not self.vlm_options.track_generated_tokens:
+                generated_tokens = []
+
+            input_prompt = prompts[i] if self.vlm_options.track_input_prompt else None
+            _log.debug(f"VLM generated response carries input prompt: {input_prompt}")
+
             decoded_text = self.vlm_options.decode_response(text)
             yield VlmPrediction(
                 text=decoded_text,
@@ -334,4 +340,5 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 num_tokens=num_tokens,
                 stop_reason=stop_reason,
                 generated_tokens=generated_tokens,
+                input_prompt=input_prompt,
             )
docs/examples/demo_layout_vlm.py (new file, vendored): 177 lines

@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+"""Demo script for the new ThreadedLayoutVlmPipeline.
+
+This script demonstrates the usage of the experimental ThreadedLayoutVlmPipeline pipeline
+that combines layout model preprocessing with VLM processing in a threaded manner.
+"""
+
+import argparse
+import logging
+import traceback
+from pathlib import Path
+
+from docling.datamodel.base_models import ConversionStatus, InputFormat
+from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat
+from docling.datamodel.vlm_model_specs import GRANITEDOCLING_TRANSFORMERS
+from docling.document_converter import DocumentConverter, PdfFormatOption
+from docling.experimental.datamodel.threaded_layout_vlm_pipeline_options import (
+    ThreadedLayoutVlmPipelineOptions,
+)
+from docling.experimental.pipeline.threaded_layout_vlm_pipeline import (
+    ThreadedLayoutVlmPipeline,
+)
+
+_log = logging.getLogger(__name__)
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser(
+        description="Demo script for the experimental ThreadedLayoutVlmPipeline"
+    )
+    parser.add_argument(
+        "--input-file",
+        type=str,
+        default="tests/data/pdf/code_and_formula.pdf",
+        help="Path to a PDF file",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="scratch/demo_layout_vlm/",
+        help="Output directory for converted files",
+    )
+    return parser.parse_args()
+
+
+# Can be used to read multiple pdf files under a folder
+# def _get_docs(input_doc_path):
+#     """Yield DocumentStream objects from list of input document paths"""
+#     for path in input_doc_path:
+#         buf = BytesIO(path.read_bytes())
+#         stream = DocumentStream(name=path.name, stream=buf)
+#         yield stream
+
+
+def openai_compatible_vlm_options(
+    model: str,
+    prompt: str,
+    format: ResponseFormat,
+    hostname_and_port,
+    temperature: float = 0.7,
+    max_tokens: int = 4096,
+    api_key: str = "",
+    skip_special_tokens=False,
+):
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+
+    options = ApiVlmOptions(
+        url=f"http://{hostname_and_port}/v1/chat/completions",  # LM studio defaults to port 1234, VLLM to 8000
+        params=dict(
+            model=model,
+            max_tokens=max_tokens,
+            skip_special_tokens=skip_special_tokens,  # needed for VLLM
+        ),
+        headers=headers,
+        prompt=prompt,
+        timeout=90,
+        scale=2.0,
+        temperature=temperature,
+        response_format=format,
+    )
+
+    return options
+
+
+def demo_threaded_layout_vlm_pipeline(
+    input_doc_path: Path, out_dir_layout_aware: Path, use_api_vlm: bool
+):
+    """Demonstrate the threaded layout+VLM pipeline."""
+
+    vlm_options = GRANITEDOCLING_TRANSFORMERS.model_copy()
+
+    if use_api_vlm:
+        vlm_options = openai_compatible_vlm_options(
+            model="granite-docling-258m-mlx",  # For VLLM use "ibm-granite/granite-docling-258M"
+            hostname_and_port="localhost:1234",  # LM studio defaults to port 1234, VLLM to 8000
+            prompt="Convert this page to docling.",
+            format=ResponseFormat.DOCTAGS,
+            api_key="",
+        )
+    vlm_options.track_input_prompt = True
+
+    # Configure pipeline options
+    print("Configuring pipeline options...")
+    pipeline_options_layout_aware = ThreadedLayoutVlmPipelineOptions(
+        # VLM configuration - defaults to GRANITEDOCLING_TRANSFORMERS
+        vlm_options=vlm_options,
+        # Layout configuration - defaults to DOCLING_LAYOUT_HERON
+        # Batch sizes for parallel processing
+        layout_batch_size=2,
+        vlm_batch_size=1,
+        # Queue configuration
+        queue_max_size=10,
+        # Image processing
+        images_scale=2.0,
+        generate_page_images=True,
+        enable_remote_services=use_api_vlm,
+    )
+
+    # Create converter with the new pipeline
+    print("Initializing DocumentConverter (this may take a while - loading models)...")
+    doc_converter_layout_enhanced = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_cls=ThreadedLayoutVlmPipeline,
+                pipeline_options=pipeline_options_layout_aware,
+            )
+        }
+    )
+
+    result_layout_aware = doc_converter_layout_enhanced.convert(
+        source=input_doc_path, raises_on_error=False
+    )
+
+    if result_layout_aware.status == ConversionStatus.FAILURE:
+        _log.error(f"Conversion failed: {result_layout_aware.status}")
+
+    doc_filename = result_layout_aware.input.file.stem
+    result_layout_aware.document.save_as_json(
+        out_dir_layout_aware / f"{doc_filename}.json"
+    )
+
+    result_layout_aware.document.save_as_html(
+        out_dir_layout_aware / f"{doc_filename}.html"
+    )
+    for page in result_layout_aware.pages:
+        _log.info("Page %s of VLM response:", page.page_no)
+        if page.predictions.vlm_response:
+            _log.info(page.predictions.vlm_response)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    try:
+        args = _parse_args()
+        _log.info(
+            f"Parsed arguments: input={args.input_file}, output={args.output_dir}"
+        )
+
+        input_path = Path(args.input_file)
+
+        if not input_path.exists():
+            raise FileNotFoundError(f"Input file does not exist: {input_path}")
+
+        if input_path.suffix.lower() != ".pdf":
+            raise ValueError(f"Input file must be a PDF: {input_path}")
+
+        out_dir_layout_aware = Path(args.output_dir) / "layout_aware/"
+        out_dir_layout_aware.mkdir(parents=True, exist_ok=True)
+
+        use_api_vlm = False  # Set to False to use inline VLM model
+
+        demo_threaded_layout_vlm_pipeline(input_path, out_dir_layout_aware, use_api_vlm)
+    except Exception:
+        traceback.print_exc()
+        raise