diff --git a/docling/cli/main.py b/docling/cli/main.py index ae275ea9..1b623b0d 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -63,6 +63,7 @@ from docling.datamodel.vlm_model_specs import ( GRANITE_VISION_TRANSFORMERS, SMOLDOCLING_MLX, SMOLDOCLING_TRANSFORMERS, + VLM2STAGE, VlmModelType, ) from docling.document_converter import ( @@ -627,6 +628,12 @@ def convert( # noqa: C901 "To run SmolDocling faster, please install mlx-vlm:\n" "pip install mlx-vlm" ) + elif vlm_model == VlmModelType.VLM2STAGE: + pipeline_options.vlm_options = VLM2STAGE + else: + raise ValueError( + f"{vlm_model} is not of type GRANITE_VISION, GRANITE_VISION_OLLAMA, SMOLDOCLING_TRANSFORMERS or VLM2STAGE" + ) pdf_format_option = PdfFormatOption( pipeline_cls=VlmPipeline, pipeline_options=pipeline_options diff --git a/docling/datamodel/asr_model_specs.py b/docling/datamodel/asr_model_specs.py index 426b5851..b16ad8f9 100644 --- a/docling/datamodel/asr_model_specs.py +++ b/docling/datamodel/asr_model_specs.py @@ -11,12 +11,13 @@ from docling.datamodel.pipeline_options_asr_model import ( # ApiAsrOptions, InferenceAsrFramework, InlineAsrNativeWhisperOptions, - TransformersModelType, + InlineAsrOptions, + # TransformersModelType, ) _log = logging.getLogger(__name__) -WHISPER_TINY = InlineAsrNativeWhisperOptions( +WHISPER_TINY: InlineAsrOptions = InlineAsrNativeWhisperOptions( repo_id="tiny", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -27,7 +28,7 @@ WHISPER_TINY = InlineAsrNativeWhisperOptions( max_time_chunk=30.0, ) -WHISPER_SMALL = InlineAsrNativeWhisperOptions( +WHISPER_SMALL: InlineAsrOptions = InlineAsrNativeWhisperOptions( repo_id="small", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -38,7 +39,7 @@ WHISPER_SMALL = InlineAsrNativeWhisperOptions( max_time_chunk=30.0, ) -WHISPER_MEDIUM = InlineAsrNativeWhisperOptions( +WHISPER_MEDIUM: InlineAsrOptions = InlineAsrNativeWhisperOptions( repo_id="medium", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -49,7 +50,7 @@ WHISPER_MEDIUM = InlineAsrNativeWhisperOptions( max_time_chunk=30.0, ) -WHISPER_BASE = InlineAsrNativeWhisperOptions( +WHISPER_BASE: InlineAsrOptions = InlineAsrNativeWhisperOptions( repo_id="base", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -60,7 +61,7 @@ WHISPER_BASE = InlineAsrNativeWhisperOptions( max_time_chunk=30.0, ) -WHISPER_LARGE = InlineAsrNativeWhisperOptions( +WHISPER_LARGE: InlineAsrOptions = InlineAsrNativeWhisperOptions( repo_id="large", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, @@ -71,7 +72,7 @@ WHISPER_LARGE = InlineAsrNativeWhisperOptions( max_time_chunk=30.0, ) -WHISPER_TURBO = InlineAsrNativeWhisperOptions( +WHISPER_TURBO: InlineAsrOptions = InlineAsrNativeWhisperOptions( repo_id="turbo", inference_framework=InferenceAsrFramework.WHISPER, verbose=True, diff --git a/docling/datamodel/layout_model_specs.py b/docling/datamodel/layout_model_specs.py index b91fa7fe..ff5c8074 100644 --- a/docling/datamodel/layout_model_specs.py +++ b/docling/datamodel/layout_model_specs.py @@ -26,8 +26,6 @@ class LayoutModelConfig(BaseModel): return self.repo_id.replace("/", "--") -# HuggingFace Layout Models - # Default Docling Layout Model DOCLING_LAYOUT_V2 = LayoutModelConfig( name="docling_layout_v2", diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 06169fb8..3a9068c4 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -12,10 +12,16 @@ from 
pydantic import ( ) from typing_extensions import deprecated -from docling.datamodel import asr_model_specs - # Import the following for backwards compatibility from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions +from docling.datamodel.asr_model_specs import ( + WHISPER_BASE, + WHISPER_LARGE, + WHISPER_MEDIUM, + WHISPER_SMALL, + WHISPER_TINY, + WHISPER_TURBO, +) from docling.datamodel.layout_model_specs import ( DOCLING_LAYOUT_EGRET_LARGE, DOCLING_LAYOUT_EGRET_MEDIUM, @@ -33,6 +39,7 @@ from docling.datamodel.pipeline_options_vlm_model import ( InferenceFramework, InlineVlmOptions, ResponseFormat, + TwoStageVlmOptions, ) from docling.datamodel.vlm_model_specs import ( GRANITE_VISION_OLLAMA as granite_vision_vlm_ollama_conversion_options, @@ -270,8 +277,9 @@ class VlmPipelineOptions(PaginatedPipelineOptions): False # (To be used with vlms, or other generative models) ) # If True, text from backend will be used instead of generated text - vlm_options: Union[InlineVlmOptions, ApiVlmOptions] = ( + vlm_options: Union[InlineVlmOptions, ApiVlmOptions, TwoStageVlmOptions] = ( smoldocling_vlm_conversion_options + # SMOLDOCLING_TRANSFORMERS ) @@ -286,7 +294,7 @@ class LayoutOptions(BaseModel): class AsrPipelineOptions(PipelineOptions): - asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY + asr_options: Union[InlineAsrOptions] = WHISPER_TINY artifacts_path: Optional[Union[Path, str]] = None diff --git a/docling/datamodel/pipeline_options_asr_model.py b/docling/datamodel/pipeline_options_asr_model.py index 20e2e453..f26aad76 100644 --- a/docling/datamodel/pipeline_options_asr_model.py +++ b/docling/datamodel/pipeline_options_asr_model.py @@ -5,10 +5,11 @@ from pydantic import AnyUrl, BaseModel from typing_extensions import deprecated from docling.datamodel.accelerator_options import AcceleratorDevice -from docling.datamodel.pipeline_options_vlm_model import ( - # InferenceFramework, - TransformersModelType, -) + +# from docling.datamodel.pipeline_options_vlm_model import ( +# InferenceFramework, +# TransformersModelType, +# ) class BaseAsrOptions(BaseModel): diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py index bcea2493..66c97ca4 100644 --- a/docling/datamodel/pipeline_options_vlm_model.py +++ b/docling/datamodel/pipeline_options_vlm_model.py @@ -6,6 +6,9 @@ from pydantic import AnyUrl, BaseModel from typing_extensions import deprecated from docling.datamodel.accelerator_options import AcceleratorDevice +from docling.datamodel.layout_model_specs import ( + LayoutModelConfig, +) class BaseVlmOptions(BaseModel): @@ -87,3 +90,12 @@ class ApiVlmOptions(BaseVlmOptions): timeout: float = 60 concurrency: int = 1 response_format: ResponseFormat + + +class TwoStageVlmOptions(BaseModel): + kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options" + + response_format: ResponseFormat # final response of the VLM + + layout_options: LayoutModelConfig # = DOCLING_LAYOUT_V2 + vlm_options: Union[InlineVlmOptions, ApiVlmOptions] # = SMOLDOCLING_TRANSFORMERS diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py index 5045c846..906e4e9c 100644 --- a/docling/datamodel/vlm_model_specs.py +++ b/docling/datamodel/vlm_model_specs.py @@ -6,12 +6,17 @@ from pydantic import ( ) from docling.datamodel.accelerator_options import AcceleratorDevice +from docling.datamodel.layout_model_specs import ( + DOCLING_LAYOUT_HERON, + DOCLING_LAYOUT_V2, +) from 
docling.datamodel.pipeline_options_vlm_model import ( ApiVlmOptions, InferenceFramework, InlineVlmOptions, ResponseFormat, TransformersModelType, + TwoStageVlmOptions, ) _log = logging.getLogger(__name__) @@ -137,8 +142,15 @@ GEMMA3_27B_MLX = InlineVlmOptions( temperature=0.0, ) +VLM2STAGE = TwoStageVlmOptions( + vlm_options=SMOLDOCLING_MLX, + layout_options=DOCLING_LAYOUT_HERON, + response_format=SMOLDOCLING_MLX.response_format, +) + class VlmModelType(str, Enum): SMOLDOCLING = "smoldocling" GRANITE_VISION = "granite_vision" GRANITE_VISION_OLLAMA = "granite_vision_ollama" + VLM2STAGE = "vlm2stage" diff --git a/docling/models/base_model.py b/docling/models/base_model.py index b0a43f40..5bf32f48 100644 --- a/docling/models/base_model.py +++ b/docling/models/base_model.py @@ -3,9 +3,16 @@ from collections.abc import Iterable from typing import Generic, Optional, Protocol, Type from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem +from PIL import Image from typing_extensions import TypeVar -from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page +from docling.datamodel.base_models import ( + Cluster, + ItemAndImageEnrichmentElement, + Page, + TextCell, + VlmPredictionToken, +) from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import BaseOptions from docling.datamodel.settings import settings @@ -19,6 +26,9 @@ class BaseModelWithOptions(Protocol): class BasePageModel(ABC): + scale: float # scale with which the page-image needs to be created (dpi = 72*scale) + max_size: int # max size of width/height of page-image + @abstractmethod def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] @@ -26,6 +36,30 @@ class BasePageModel(ABC): pass +class BaseLayoutModel(BasePageModel): + @abstractmethod + def predict_on_page_image(self, *, page_image: Image.Image) -> list[Cluster]: + pass + + @abstractmethod + def postprocess_on_page_image( + self, *, page: Page, clusters: list[Cluster] + ) -> tuple[Page, list[Cluster], list[TextCell]]: + pass + + +class BaseVlmModel(BasePageModel): + @abstractmethod + def get_user_prompt(self, page: Optional[Page]) -> str: + pass + + @abstractmethod + def predict_on_page_image( + self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False + ) -> tuple[str, Optional[list[VlmPredictionToken]]]: + pass + + EnrichElementT = TypeVar("EnrichElementT", default=NodeItem) diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 05a86f31..2b7947da 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -7,6 +7,7 @@ from typing import Optional import numpy as np from docling_core.types.doc import DocItemLabel +from docling_core.types.doc.page import TextCell from PIL import Image from docling.datamodel.accelerator_options import AcceleratorOptions @@ -15,7 +16,7 @@ from docling.datamodel.document import ConversionResult from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_V2, LayoutModelConfig from docling.datamodel.pipeline_options import LayoutOptions from docling.datamodel.settings import settings -from docling.models.base_model import BasePageModel +from docling.models.base_model import BaseLayoutModel, BasePageModel from docling.models.utils.hf_model_download import download_hf_model from docling.utils.accelerator_utils import decide_device from docling.utils.layout_postprocessor import LayoutPostprocessor @@ -25,7 +26,7 @@ from docling.utils.visualization import draw_clusters _log = 
logging.getLogger(__name__) -class LayoutModel(BasePageModel): +class LayoutModel(BaseLayoutModel): TEXT_ELEM_LABELS = [ DocItemLabel.TEXT, DocItemLabel.FOOTNOTE, @@ -158,6 +159,7 @@ class LayoutModel(BasePageModel): page_image = page.get_image(scale=1.0) assert page_image is not None + """ clusters = [] for ix, pred_item in enumerate( self.layout_predictor.predict(page_image) @@ -176,14 +178,18 @@ class LayoutModel(BasePageModel): cells=[], ) clusters.append(cluster) + """ + predicted_clusters = self.predict_on_page_image( + page_image=page_image + ) if settings.debug.visualize_raw_layout: self.draw_clusters_and_cells_side_by_side( - conv_res, page, clusters, mode_prefix="raw" + conv_res, page, predicted_clusters, mode_prefix="raw" ) # Apply postprocessing - + """ processed_clusters, processed_cells = LayoutPostprocessor( page, clusters, self.options ).postprocess() @@ -210,6 +216,30 @@ class LayoutModel(BasePageModel): page.predictions.layout = LayoutPrediction( clusters=processed_clusters ) + """ + page, processed_clusters, processed_cells = ( + self.postprocess_on_page_image( + page=page, clusters=predicted_clusters + ) + ) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "Mean of empty slice|invalid value encountered in scalar divide", + RuntimeWarning, + "numpy", + ) + + conv_res.confidence.pages[page.page_no].layout_score = float( + np.mean([c.confidence for c in processed_clusters]) + ) + + conv_res.confidence.pages[page.page_no].ocr_score = float( + np.mean( + [c.confidence for c in processed_cells if c.from_ocr] + ) + ) if settings.debug.visualize_layout: self.draw_clusters_and_cells_side_by_side( @@ -217,3 +247,34 @@ class LayoutModel(BasePageModel): ) yield page + + def predict_on_page_image(self, *, page_image: Image.Image) -> list[Cluster]: + pred_items = self.layout_predictor.predict(page_image) + + clusters = [] + for ix, pred_item in enumerate(pred_items): + label = DocItemLabel( + pred_item["label"].lower().replace(" ", "_").replace("-", "_") + ) # Temporary, until docling-ibm-model uses docling-core types + cluster = Cluster( + id=ix, + label=label, + confidence=pred_item["confidence"], + bbox=BoundingBox.model_validate(pred_item), + cells=[], + ) + clusters.append(cluster) + + return clusters + + def postprocess_on_page_image( + self, *, page: Page, clusters: list[Cluster] + ) -> tuple[Page, list[Cluster], list[TextCell]]: + processed_clusters, processed_cells = LayoutPostprocessor( + page, clusters, self.options + ).postprocess() + # Note: LayoutPostprocessor updates page.cells and page.parsed_page internally + + page.predictions.layout = LayoutPrediction(clusters=processed_clusters) + + return page, processed_clusters, processed_cells diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py index d84925dd..5434ee50 100644 --- a/docling/models/vlm_models_inline/hf_transformers_model.py +++ b/docling/models/vlm_models_inline/hf_transformers_model.py @@ -5,17 +5,19 @@ from collections.abc import Iterable from pathlib import Path from typing import Any, Optional +from PIL import Image + from docling.datamodel.accelerator_options import ( AcceleratorOptions, ) -from docling.datamodel.base_models import Page, VlmPrediction +from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options_vlm_model import ( InlineVlmOptions, TransformersModelType, 
TransformersPromptStyle, ) -from docling.models.base_model import BasePageModel +from docling.models.base_model import BasePageModel, BaseVlmModel from docling.models.utils.hf_model_download import ( HuggingFaceModelDownloadMixin, ) @@ -25,7 +27,7 @@ from docling.utils.profiling import TimeRecorder _log = logging.getLogger(__name__) -class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin): +class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin): def __init__( self, enabled: bool, @@ -37,6 +39,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix self.vlm_options = vlm_options + self.scale = self.vlm_options.scale + self.max_size = 512 + if isinstance(self.vlm_options.max_size, int): + self.max_size = self.vlm_options.max_size + if self.enabled: import torch from transformers import ( @@ -119,6 +126,43 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix # Load generation config self.generation_config = GenerationConfig.from_pretrained(artifacts_path) + def get_user_prompt(self, page: Optional[Page]) -> str: + # Define prompt structure + user_prompt = "" + if callable(self.vlm_options.prompt) and page is not None: + user_prompt = self.vlm_options.prompt(page.parsed_page) + elif isinstance(self.vlm_options.prompt, str): + user_prompt = self.vlm_options.prompt + + prompt = self.formulate_prompt(user_prompt) + return prompt + + def predict_on_page_image( + self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False + ) -> tuple[str, Optional[list[VlmPredictionToken]]]: + output = "" + + inputs = self.processor( + text=prompt, images=[page_image], return_tensors="pt" + ).to(self.device) + + # Call model to generate: + generated_ids = self.vlm_model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + use_cache=self.use_cache, + temperature=self.temperature, + generation_config=self.generation_config, + **self.vlm_options.extra_generation_config, + ) + + output = self.processor.batch_decode( + generated_ids[:, inputs["input_ids"].shape[1] :], + skip_special_tokens=False, + )[0] + + return output, [] + def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: @@ -130,22 +174,29 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix with TimeRecorder(conv_res, "vlm"): assert page.size is not None - hi_res_image = page.get_image( + page_image = page.get_image( scale=self.vlm_options.scale, max_size=self.vlm_options.max_size ) + assert page_image is not None + # Define prompt structure + """ if callable(self.vlm_options.prompt): user_prompt = self.vlm_options.prompt(page.parsed_page) else: user_prompt = self.vlm_options.prompt prompt = self.formulate_prompt(user_prompt) - - inputs = self.processor( - text=prompt, images=[hi_res_image], return_tensors="pt" - ).to(self.device) + """ + prompt = self.get_user_prompt(page=page) start_time = time.time() + + """ + inputs = self.processor( + text=prompt, images=[page_image], return_tensors="pt" + ).to(self.device) + # Call model to generate: generated_ids = self.vlm_model.generate( **inputs, @@ -166,9 +217,14 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix _log.debug( f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds." 
) + """ + generated_text = self.predict_on_page_image( + page_image=page_image, prompt=prompt, output_tokens=False + ) + page.predictions.vlm_response = VlmPrediction( - text=generated_texts, - generation_time=generation_time, + text=generated_text, + generation_time=time.time() - start_time, ) yield page diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py index 647ce531..fa28de7f 100644 --- a/docling/models/vlm_models_inline/mlx_model.py +++ b/docling/models/vlm_models_inline/mlx_model.py @@ -4,13 +4,15 @@ from collections.abc import Iterable from pathlib import Path from typing import Optional +from PIL import Image + from docling.datamodel.accelerator_options import ( AcceleratorOptions, ) from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions -from docling.models.base_model import BasePageModel +from docling.models.base_model import BasePageModel, BaseVlmModel from docling.models.utils.hf_model_download import ( HuggingFaceModelDownloadMixin, ) @@ -19,7 +21,7 @@ from docling.utils.profiling import TimeRecorder _log = logging.getLogger(__name__) -class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): +class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin): def __init__( self, enabled: bool, @@ -28,10 +30,15 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): vlm_options: InlineVlmOptions, ): self.enabled = enabled - self.vlm_options = vlm_options + self.max_tokens = vlm_options.max_new_tokens self.temperature = vlm_options.temperature + self.scale = self.vlm_options.scale + + self.max_size = 512 + if isinstance(self.vlm_options.max_size, int): + self.max_size = self.vlm_options.max_size if self.enabled: try: @@ -60,6 +67,55 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): self.vlm_model, self.processor = load(artifacts_path) self.config = load_config(artifacts_path) + def get_user_prompt(self, page: Optional[Page]) -> str: + if callable(self.vlm_options.prompt) and page is not None: + return self.vlm_options.prompt(page.parsed_page) + else: + user_prompt = self.vlm_options.prompt + prompt = self.apply_chat_template( + self.processor, self.config, user_prompt, num_images=1 + ) + return prompt + + def predict_on_page_image( + self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False + ) -> tuple[str, Optional[list[VlmPredictionToken]]]: + tokens = [] + output = "" + for token in self.stream_generate( + self.vlm_model, + self.processor, + prompt, + [page_image], + max_tokens=self.max_tokens, + verbose=False, + temp=self.temperature, + ): + if len(token.logprobs.shape) == 1: + tokens.append( + VlmPredictionToken( + text=token.text, + token=token.token, + logprob=token.logprobs[token.token], + ) + ) + elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1: + tokens.append( + VlmPredictionToken( + text=token.text, + token=token.token, + logprob=token.logprobs[0, token.token], + ) + ) + else: + _log.warning(f"incompatible shape for logprobs: {token.logprobs.shape}") + + output += token.text + if "" in token.text: + break + + return output, tokens + def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: @@ -71,19 +127,23 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): with TimeRecorder(conv_res, 
f"vlm-mlx-{self.vlm_options.repo_id}"): assert page.size is not None - hi_res_image = page.get_image( + page_image = page.get_image( scale=self.vlm_options.scale, max_size=self.vlm_options.max_size ) - if hi_res_image is not None: - im_width, im_height = hi_res_image.size + """ + if page_image is not None: + im_width, im_height = page_image.size + """ + assert page_image is not None # populate page_tags with predicted doc tags page_tags = "" - if hi_res_image: - if hi_res_image.mode != "RGB": - hi_res_image = hi_res_image.convert("RGB") + if page_image: + if page_image.mode != "RGB": + page_image = page_image.convert("RGB") + """ if callable(self.vlm_options.prompt): user_prompt = self.vlm_options.prompt(page.parsed_page) else: @@ -91,11 +151,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): prompt = self.apply_chat_template( self.processor, self.config, user_prompt, num_images=1 ) - - start_time = time.time() - _log.debug("start generating ...") + """ + prompt = self.get_user_prompt(page) # Call model to generate: + start_time = time.time() + """ tokens: list[VlmPredictionToken] = [] output = "" @@ -103,7 +164,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): self.vlm_model, self.processor, prompt, - [hi_res_image], + [page_image], max_tokens=self.max_tokens, verbose=False, temp=self.temperature, @@ -135,13 +196,20 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): output += token.text if "" in token.text: break + """ + output, tokens = self.predict_on_page_image( + page_image=page_image, prompt=prompt, output_tokens=True + ) generation_time = time.time() - start_time page_tags = output + """ _log.debug( f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)." 
) + """ + page.predictions.vlm_response = VlmPrediction( text=page_tags, generation_time=generation_time, diff --git a/docling/models/vlm_models_inline/two_stage_vlm_model.py b/docling/models/vlm_models_inline/two_stage_vlm_model.py new file mode 100644 index 00000000..2bf5958f --- /dev/null +++ b/docling/models/vlm_models_inline/two_stage_vlm_model.py @@ -0,0 +1,119 @@ +import importlib.metadata +import logging +import time +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Optional + +from docling.datamodel.accelerator_options import ( + AcceleratorOptions, +) +from docling.datamodel.base_models import Cluster, Page, VlmPrediction +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options_vlm_model import ( + InlineVlmOptions, + TransformersModelType, + TransformersPromptStyle, +) +from docling.models.base_model import BaseLayoutModel, BasePageModel, BaseVlmModel +from docling.models.layout_model import LayoutModel +from docling.models.utils.hf_model_download import ( + HuggingFaceModelDownloadMixin, +) +from docling.utils.accelerator_utils import decide_device +from docling.utils.profiling import TimeRecorder + +_log = logging.getLogger(__name__) + + +class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin): + def __init__( + self, + *, + layout_model: BaseLayoutModel, + vlm_model: BaseVlmModel, + ): + self.layout_model = layout_model + self.vlm_model = vlm_model + + def __call__( + self, conv_res: ConversionResult, page_batch: Iterable[Page] + ) -> Iterable[Page]: + for page in page_batch: + assert page._backend is not None + if not page._backend.is_valid(): + yield page + else: + with TimeRecorder(conv_res, "two-staged-vlm"): + assert page.size is not None + + page_image = page.get_image( + scale=self.vlm_model.scale, max_size=self.vlm_model.max_size + ) + assert page_image is not None + + pred_clusters = self.layout_model.predict_on_page_image( + page_image=page_image + ) + + page, processed_clusters, processed_cells = ( + self.layout_model.postprocess_on_page_image( + page=page, clusters=pred_clusters + ) + ) + + user_prompt = self.vlm_model.get_user_prompt(page=page) + prompt = self.formulate_prompt( + user_prompt=user_prompt, + clusters=processed_clusters, + image_width=page_image.width, + image_height=page_image.height, + ) + + start_time = time.time() + generated_text, generated_tokens = ( + self.vlm_model.predict_on_page_image( + page_image=page_image, prompt=prompt + ) + ) + print("generated-text: \n", generated_text, "\n") + page.predictions.vlm_response = VlmPrediction( + text=generated_text, + generation_time=time.time() - start_time, + generated_tokens=generated_tokens, + ) + exit(-1) + + yield page + + def formulate_prompt( + self, + *, + user_prompt: str, + clusters: list[Cluster], + image_width: int, + image_height: int, + vlm_width: int = 512, + vlm_height: int = 512, + ) -> str: + """Formulate a prompt for the VLM.""" + + known_clusters = ["here is a list of unsorted text-blocks:", ""] + for cluster in clusters: + print(" => ", cluster) + + loc_l = f"" + loc_b = f"" + loc_r = f"" + loc_t = f"" + + known_clusters.append( + f"<{cluster.label}>{loc_l}{loc_b}{loc_r}{loc_t}" + ) + + known_clusters.append("") + + user_prompt = "\n".join(known_clusters) + f"\n\n{user_prompt}" + print("user-prompt: ", user_prompt, "\n") + + return user_prompt diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py index ab474fab..1c94d977 100644 --- 
a/docling/pipeline/vlm_pipeline.py +++ b/docling/pipeline/vlm_pipeline.py @@ -26,21 +26,24 @@ from docling.backend.md_backend import MarkdownDocumentBackend from docling.backend.pdf_backend import PdfDocumentBackend from docling.datamodel.base_models import InputFormat, Page from docling.datamodel.document import ConversionResult, InputDocument -from docling.datamodel.pipeline_options import ( - VlmPipelineOptions, -) +from docling.datamodel.pipeline_options import LayoutOptions, VlmPipelineOptions from docling.datamodel.pipeline_options_vlm_model import ( ApiVlmOptions, InferenceFramework, InlineVlmOptions, ResponseFormat, + TwoStageVlmOptions, ) from docling.datamodel.settings import settings from docling.models.api_vlm_model import ApiVlmModel +from docling.models.layout_model import LayoutModel from docling.models.vlm_models_inline.hf_transformers_model import ( HuggingFaceTransformersVlmModel, ) from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel +from docling.models.vlm_models_inline.two_stage_vlm_model import ( + TwoStageVlmModel, +) from docling.pipeline.base_pipeline import PaginatedPipeline from docling.utils.profiling import ProfilingScope, TimeRecorder @@ -107,6 +110,53 @@ class VlmPipeline(PaginatedPipeline): raise ValueError( f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}" ) + elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions): + twostagevlm_options = cast( + TwoStageVlmOptions, self.pipeline_options.vlm_options + ) + + stage_1_options = twostagevlm_options.layout_options + stage_2_options = twostagevlm_options.vlm_options + + layout_model = LayoutModel( + artifacts_path=artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + options=LayoutOptions( + create_orphan_clusters=False, model_spec=stage_1_options + ), + ) + + if ( + isinstance(stage_2_options, InlineVlmOptions) + and stage_2_options.inference_framework == InferenceFramework.MLX + ): + vlm_model_mlx = HuggingFaceMlxModel( + enabled=True, # must be always enabled for this pipeline to make sense. + artifacts_path=artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + vlm_options=stage_2_options, + ) + self.build_pipe = [ + TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model_mlx) + ] + elif ( + isinstance(stage_2_options, InlineVlmOptions) + and stage_2_options.inference_framework + == InferenceFramework.TRANSFORMERS + ): + vlm_model_hf = HuggingFaceTransformersVlmModel( + enabled=True, # must be always enabled for this pipeline to make sense. 
+ artifacts_path=artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + vlm_options=stage_2_options, + ) + self.build_pipe = [ + TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model_hf) + ] + else: + raise ValueError( + f"Could not instantiate the right type of VLM pipeline: {stage_2_options}" + ) self.enrichment_pipe = [ # Other models working on `NodeItem` elements in the DoclingDocument diff --git a/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.json b/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.json index dd51e390..e938e2d7 100644 --- a/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.json +++ b/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.json @@ -213,10 +213,10 @@ "prov": [ { "bbox": [ - 139.66741943359375, + 139.6674041748047, 322.5054626464844, 475.00927734375, - 454.45458984375 + 454.4546203613281 ], "page": 1, "span": [ diff --git a/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.pages.json b/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.pages.json index 3010fbb6..3c219d95 100644 --- a/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.pages.json +++ b/tests/data/groundtruth/docling_v1/2305.03393v1-pg9.pages.json @@ -2705,7 +2705,7 @@ "b": 102.78223000000003, "coord_origin": "TOPLEFT" }, - "confidence": 0.9373534917831421, + "confidence": 0.9373533129692078, "cells": [ { "index": 0, @@ -2745,7 +2745,7 @@ "b": 102.78223000000003, "coord_origin": "TOPLEFT" }, - "confidence": 0.8858680725097656, + "confidence": 0.8858679533004761, "cells": [ { "index": 1, @@ -2785,7 +2785,7 @@ "b": 152.90697999999998, "coord_origin": "TOPLEFT" }, - "confidence": 0.9806433916091919, + "confidence": 0.9806435108184814, "cells": [ { "index": 2, @@ -2940,7 +2940,7 @@ "b": 255.42400999999995, "coord_origin": "TOPLEFT" }, - "confidence": 0.98504239320755, + "confidence": 0.9850425124168396, "cells": [ { "index": 7, @@ -3155,7 +3155,7 @@ "b": 327.98218, "coord_origin": "TOPLEFT" }, - "confidence": 0.9591909050941467, + "confidence": 0.9591907262802124, "cells": [ { "index": 15, @@ -3339,8 +3339,8 @@ "id": 0, "label": "table", "bbox": { - "l": 139.66741943359375, - "t": 337.54541015625, + "l": 139.6674041748047, + "t": 337.5453796386719, "r": 475.00927734375, "b": 469.4945373535156, "coord_origin": "TOPLEFT" @@ -7846,7 +7846,7 @@ "b": 518.17419, "coord_origin": "TOPLEFT" }, - "confidence": 0.9589294195175171, + "confidence": 0.9589295387268066, "cells": [ { "index": 91, @@ -7911,7 +7911,7 @@ "b": 618.3, "coord_origin": "TOPLEFT" }, - "confidence": 0.9849975109100342, + "confidence": 0.9849976301193237, "cells": [ { "index": 93, @@ -8243,8 +8243,8 @@ "id": 0, "label": "table", "bbox": { - "l": 139.66741943359375, - "t": 337.54541015625, + "l": 139.6674041748047, + "t": 337.5453796386719, "r": 475.00927734375, "b": 469.4945373535156, "coord_origin": "TOPLEFT" @@ -13641,7 +13641,7 @@ "b": 102.78223000000003, "coord_origin": "TOPLEFT" }, - "confidence": 0.9373534917831421, + "confidence": 0.9373533129692078, "cells": [ { "index": 0, @@ -13687,7 +13687,7 @@ "b": 102.78223000000003, "coord_origin": "TOPLEFT" }, - "confidence": 0.8858680725097656, + "confidence": 0.8858679533004761, "cells": [ { "index": 1, @@ -13733,7 +13733,7 @@ "b": 152.90697999999998, "coord_origin": "TOPLEFT" }, - "confidence": 0.9806433916091919, + "confidence": 0.9806435108184814, "cells": [ { "index": 2, @@ -13900,7 +13900,7 @@ "b": 255.42400999999995, "coord_origin": "TOPLEFT" }, - "confidence": 0.98504239320755, + "confidence": 0.9850425124168396, "cells": [ { "index": 7, @@ 
-14121,7 +14121,7 @@ "b": 327.98218, "coord_origin": "TOPLEFT" }, - "confidence": 0.9591909050941467, + "confidence": 0.9591907262802124, "cells": [ { "index": 15, @@ -14311,8 +14311,8 @@ "id": 0, "label": "table", "bbox": { - "l": 139.66741943359375, - "t": 337.54541015625, + "l": 139.6674041748047, + "t": 337.5453796386719, "r": 475.00927734375, "b": 469.4945373535156, "coord_origin": "TOPLEFT" @@ -19701,7 +19701,7 @@ "b": 518.17419, "coord_origin": "TOPLEFT" }, - "confidence": 0.9589294195175171, + "confidence": 0.9589295387268066, "cells": [ { "index": 91, @@ -19772,7 +19772,7 @@ "b": 618.3, "coord_origin": "TOPLEFT" }, - "confidence": 0.9849975109100342, + "confidence": 0.9849976301193237, "cells": [ { "index": 93, @@ -20116,7 +20116,7 @@ "b": 152.90697999999998, "coord_origin": "TOPLEFT" }, - "confidence": 0.9806433916091919, + "confidence": 0.9806435108184814, "cells": [ { "index": 2, @@ -20283,7 +20283,7 @@ "b": 255.42400999999995, "coord_origin": "TOPLEFT" }, - "confidence": 0.98504239320755, + "confidence": 0.9850425124168396, "cells": [ { "index": 7, @@ -20504,7 +20504,7 @@ "b": 327.98218, "coord_origin": "TOPLEFT" }, - "confidence": 0.9591909050941467, + "confidence": 0.9591907262802124, "cells": [ { "index": 15, @@ -20694,8 +20694,8 @@ "id": 0, "label": "table", "bbox": { - "l": 139.66741943359375, - "t": 337.54541015625, + "l": 139.6674041748047, + "t": 337.5453796386719, "r": 475.00927734375, "b": 469.4945373535156, "coord_origin": "TOPLEFT" @@ -26084,7 +26084,7 @@ "b": 518.17419, "coord_origin": "TOPLEFT" }, - "confidence": 0.9589294195175171, + "confidence": 0.9589295387268066, "cells": [ { "index": 91, @@ -26155,7 +26155,7 @@ "b": 618.3, "coord_origin": "TOPLEFT" }, - "confidence": 0.9849975109100342, + "confidence": 0.9849976301193237, "cells": [ { "index": 93, @@ -26499,7 +26499,7 @@ "b": 102.78223000000003, "coord_origin": "TOPLEFT" }, - "confidence": 0.9373534917831421, + "confidence": 0.9373533129692078, "cells": [ { "index": 0, @@ -26545,7 +26545,7 @@ "b": 102.78223000000003, "coord_origin": "TOPLEFT" }, - "confidence": 0.8858680725097656, + "confidence": 0.8858679533004761, "cells": [ { "index": 1,
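
Usage notes on the patch above (hedged sketches; all assume the patched docling is installed). The CLI now accepts a `vlm2stage` value for the VLM model choice and maps it to the new `VLM2STAGE` preset; on the command line this should correspond to something like `docling --pipeline vlm --vlm-model vlm2stage <file>`. The same preset can be selected through the Python API:

```python
# Minimal sketch: select the new two-stage preset through the Python API.
# "report.pdf" is an illustrative input path.
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.vlm_model_specs import VLM2STAGE
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(vlm_options=VLM2STAGE)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("report.pdf")
print(result.document.export_to_markdown())
```

With this change, a `--vlm-model` value outside the handled presets now raises a `ValueError` instead of silently falling through.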
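
The Whisper presets in `asr_model_specs.py` are now explicitly typed as `InlineAsrOptions`, and `pipeline_options.py` imports them directly instead of the whole module. A minimal sketch of overriding the `WHISPER_TINY` default:

```python
# Sketch: swap the WHISPER_TINY default for another preset from the patch.
from docling.datamodel.asr_model_specs import WHISPER_TURBO
from docling.datamodel.pipeline_options import AsrPipelineOptions

asr_pipeline_options = AsrPipelineOptions(asr_options=WHISPER_TURBO)
```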
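
`TwoStageVlmOptions` pairs a stage-1 layout spec (`LayoutModelConfig`) with a stage-2 VLM spec and a final `response_format`; `VLM2STAGE` is the shipped preset (`DOCLING_LAYOUT_HERON` + `SMOLDOCLING_MLX`). A sketch of composing an alternative pairing; note that the pipeline wiring currently dispatches only MLX and Transformers inline options for stage 2, so other combinations raise a `ValueError`:

```python
# Sketch: an alternative layout/VLM pairing, mirroring the VLM2STAGE preset.
from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_V2
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import TwoStageVlmOptions
from docling.datamodel.vlm_model_specs import SMOLDOCLING_TRANSFORMERS

custom_two_stage = TwoStageVlmOptions(
    layout_options=DOCLING_LAYOUT_V2,         # stage 1: layout detector
    vlm_options=SMOLDOCLING_TRANSFORMERS,     # stage 2: inline VLM
    response_format=SMOLDOCLING_TRANSFORMERS.response_format,
)

# Plugs into the same VlmPipeline path as the shipped preset.
pipeline_options = VlmPipelineOptions(vlm_options=custom_two_stage)
```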
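
`base_model.py` gains `BaseLayoutModel` and `BaseVlmModel` abstract interfaces plus per-model `scale`/`max_size` attributes (page images are rendered at `dpi = 72 * scale`), which is what lets `TwoStageVlmModel` stay agnostic of the concrete backends. A stub showing the surface a custom stage-2 backend would implement; the `EchoVlmModel` class is hypothetical:

```python
# Hypothetical stage-2 backend; illustrates the new BaseVlmModel contract only.
from collections.abc import Iterable
from typing import Optional

from PIL import Image

from docling.datamodel.base_models import Page, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.models.base_model import BaseVlmModel


class EchoVlmModel(BaseVlmModel):
    scale = 2.0      # page rendered at 72 * scale dpi
    max_size = 512   # cap on the longer page-image edge

    def get_user_prompt(self, page: Optional[Page]) -> str:
        return "Convert this page to DocTags."

    def predict_on_page_image(
        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
        # A real backend would run inference here; this one just echoes the prompt.
        return f"{prompt} (image {page_image.width}x{page_image.height})", []

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        # Pass pages through unchanged; a real backend fills predictions here.
        yield from page_batch
```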
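
`LayoutModel` now factors the predictor call and the postprocessing into `predict_on_page_image` and `postprocess_on_page_image`, so the two-stage model can reuse them on an already rendered image. A sketch of running just the detection step standalone (assumes the layout weights can be resolved or downloaded):

```python
# Sketch: run only the layout stage on an already rendered page image.
from PIL import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import LayoutOptions
from docling.models.layout_model import LayoutModel

layout_model = LayoutModel(
    artifacts_path=None,               # None lets docling resolve/download the weights
    accelerator_options=AcceleratorOptions(),
    options=LayoutOptions(),           # default Docling layout spec
)

page_image = Image.open("page_0.png")  # illustrative path
clusters = layout_model.predict_on_page_image(page_image=page_image)
for cluster in clusters:
    print(cluster.label, cluster.confidence, cluster.bbox)
```

`postprocess_on_page_image` additionally needs the `Page` object, since the postprocessor updates `page.cells` and `page.parsed_page` in place.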
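
Both inline VLM backends now expose `get_user_prompt` and `predict_on_page_image`, the latter returning a `(text, tokens)` pair that callers unpack before building a `VlmPrediction`. A sketch of driving the Transformers backend directly on one page image (the model loads when constructed with `enabled=True`, so this is heavyweight):

```python
# Sketch: call the Transformers backend directly on a single page image.
import time

from PIL import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import VlmPrediction
from docling.datamodel.vlm_model_specs import SMOLDOCLING_TRANSFORMERS
from docling.models.vlm_models_inline.hf_transformers_model import (
    HuggingFaceTransformersVlmModel,
)

vlm = HuggingFaceTransformersVlmModel(
    enabled=True,
    artifacts_path=None,               # None resolves/downloads the model repo
    accelerator_options=AcceleratorOptions(),
    vlm_options=SMOLDOCLING_TRANSFORMERS,
)

page_image = Image.open("page_0.png").convert("RGB")  # illustrative path
prompt = vlm.get_user_prompt(page=None)               # static prompt, no page context

start = time.time()
text, tokens = vlm.predict_on_page_image(page_image=page_image, prompt=prompt)
prediction = VlmPrediction(text=text, generation_time=time.time() - start)
```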
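
`TwoStageVlmModel.formulate_prompt` serializes the detected clusters into the stage-2 prompt as labeled blocks with location tokens. The exact token strings are not fixed by the patch text as shown above, so the `<loc_*>` format and the 0-500 grid in this sketch are assumptions borrowed from the DocTags convention:

```python
# Hedged sketch of the prompt construction: label each cluster and attach
# location tokens normalized onto a 0-500 grid (assumed, DocTags-style).
from docling.datamodel.base_models import Cluster


def formulate_prompt_sketch(
    user_prompt: str,
    clusters: list[Cluster],
    image_width: int,
    image_height: int,
    grid: int = 500,
) -> str:
    blocks = ["here is a list of unsorted text-blocks:"]
    for cluster in clusters:
        # Normalize the bbox corners onto the prompt coordinate grid.
        loc_l = f"<loc_{int(grid * cluster.bbox.l / image_width)}>"
        loc_b = f"<loc_{int(grid * cluster.bbox.b / image_height)}>"
        loc_r = f"<loc_{int(grid * cluster.bbox.r / image_width)}>"
        loc_t = f"<loc_{int(grid * cluster.bbox.t / image_height)}>"
        blocks.append(f"<{cluster.label}>{loc_l}{loc_b}{loc_r}{loc_t}</{cluster.label}>")
    return "\n".join(blocks) + f"\n\n{user_prompt}"
```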