diff --git a/docling/models/base_model.py b/docling/models/base_model.py
index c8691e17..5bf32f48 100644
--- a/docling/models/base_model.py
+++ b/docling/models/base_model.py
@@ -11,6 +11,7 @@ from docling.datamodel.base_models import (
     ItemAndImageEnrichmentElement,
     Page,
     TextCell,
+    VlmPredictionToken,
 )
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
@@ -49,7 +50,13 @@ class BaseLayoutModel(BasePageModel):
 
 class BaseVlmModel(BasePageModel):
     @abstractmethod
-    def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        pass
+
+    @abstractmethod
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
         pass
 
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 3513de3e..2c7b4b0a 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -38,7 +38,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.vlm_options = vlm_options
 
         self.scale = self.vlm_options.scale
-        self.max_size = self.vlm_options.max_size
+        # self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             import torch
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index ddeea379..c28abe41 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -4,6 +4,8 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
+from PIL import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
@@ -33,7 +35,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.max_tokens = vlm_options.max_new_tokens
         self.temperature = vlm_options.temperature
         self.scale = self.vlm_options.scale
-        self.max_size = self.vlm_options.max_size
+        # self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             try:
@@ -62,6 +64,55 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
 
+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        if callable(self.vlm_options.prompt) and page is not None:
+            return self.vlm_options.prompt(page.parsed_page)
+        else:
+            user_prompt = self.vlm_options.prompt
+            prompt = self.apply_chat_template(
+                self.processor, self.config, user_prompt, num_images=1
+            )
+            return prompt
+
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
+        tokens = []
+        output = ""
+        for token in self.stream_generate(
+            self.vlm_model,
+            self.processor,
+            prompt,
+            [page_image],
+            max_tokens=self.max_tokens,
+            verbose=False,
+            temp=self.temperature,
+        ):
+            if len(token.logprobs.shape) == 1:
+                tokens.append(
+                    VlmPredictionToken(
+                        text=token.text,
+                        token=token.token,
+                        logprob=token.logprobs[token.token],
+                    )
+                )
+            elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1:
+                tokens.append(
+                    VlmPredictionToken(
+                        text=token.text,
+                        token=token.token,
+                        logprob=token.logprobs[0, token.token],
+                    )
+                )
+            else:
+                _log.warning(f"incompatible shape for logprobs: {token.logprobs.shape}")
+
+            output += token.text
+            if "</doctag>" in token.text:
+                break
+
+        return output, tokens
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -73,19 +124,23 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                 with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
                     assert page.size is not None
 
-                    hi_res_image = page.get_image(
+                    page_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
+                    """
+                    if page_image is not None:
+                        im_width, im_height = page_image.size
+                    """
+                    assert page_image is not None
 
                     # populate page_tags with predicted doc tags
                     page_tags = ""
 
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
+                    if page_image:
+                        if page_image.mode != "RGB":
+                            page_image = page_image.convert("RGB")
+                        """
                         if callable(self.vlm_options.prompt):
                             user_prompt = self.vlm_options.prompt(page.parsed_page)
                         else:
@@ -93,11 +148,12 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                             prompt = self.apply_chat_template(
                                 self.processor, self.config, user_prompt, num_images=1
                             )
-
-                        start_time = time.time()
-                        _log.debug("start generating ...")
+                        """
+                        prompt = self.get_user_prompt(page)
 
                         # Call model to generate:
+                        start_time = time.time()
+                        """
                         tokens: list[VlmPredictionToken] = []
                         output = ""
@@ -105,7 +161,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                             self.vlm_model,
                             self.processor,
                             prompt,
-                            [hi_res_image],
+                            [page_image],
                             max_tokens=self.max_tokens,
                             verbose=False,
                             temp=self.temperature,
@@ -137,13 +193,20 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                             output += token.text
                             if "</doctag>" in token.text:
                                 break
+                        """
+                        output, tokens = self.predict_on_page_image(
+                            page_image=page_image, prompt=prompt, output_tokens=True
+                        )
 
                         generation_time = time.time() - start_time
                         page_tags = output
 
+                        """
                         _log.debug(
                             f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                         )
+                        """
+
                         page.predictions.vlm_response = VlmPrediction(
                             text=page_tags,
                             generation_time=generation_time,
diff --git a/docling/models/vlm_models_inline/two_stage_vlm_model.py b/docling/models/vlm_models_inline/two_stage_vlm_model.py
index 131b874e..846fe991 100644
--- a/docling/models/vlm_models_inline/two_stage_vlm_model.py
+++ b/docling/models/vlm_models_inline/two_stage_vlm_model.py
@@ -61,64 +61,24 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                             page=page, clusters=pred_clusters
                         )
                     )
-
-                    # Define prompt structure
-                    if callable(self.vlm_options.prompt):
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
-                    prompt = self.formulate_prompt(user_prompt, processed_clusters)
+                    user_prompt = self.vlm_model.get_user_prompt(page=page)
+                    prompt = self.formulate_prompt(
+                        user_prompt=user_prompt, clusters=processed_clusters
+                    )
 
-                    generated_text, generation_time = self.vlm_model.predict_on_image(
+                    start_time = time.time()
+                    generated_text = self.vlm_model.predict_on_page_image(
                         page_image=page_image, prompt=prompt
                     )
 
                     page.predictions.vlm_response = VlmPrediction(
-                        text=generated_text,
-                        generation_time=generation_time,
+                        text=generated_text, generation_time=time.time() - start_time
                     )
 
                 yield page
 
-    def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
+    def formulate_prompt(self, *, user_prompt: str, clusters: list[Cluster]) -> str:
         """Formulate a prompt for the VLM."""
-        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
-            return user_prompt
-
-        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
-            _log.debug("Using specialized prompt for Phi-4")
-            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
-
-            user_prompt = "<|user|>"
-            assistant_prompt = "<|assistant|>"
-            prompt_suffix = "<|end|>"
-
-            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
-            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
-
-            return prompt
-
-        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "This is a page from a document.",
-                        },
-                        {"type": "image"},
-                        {"type": "text", "text": user_prompt},
-                    ],
-                }
-            ]
-            prompt = self.processor.apply_chat_template(
-                messages, add_generation_prompt=False
-            )
-            return prompt
-
-        raise RuntimeError(
-            f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
-        )
+        return user_prompt
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index 01be3693..aac61d8d 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -26,12 +26,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import ConversionResult, InputDocument
-from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
+from docling.datamodel.pipeline_options import VlmPipelineOptions
 from docling.datamodel.pipeline_options_vlm_model import (
     ApiVlmOptions,
     InferenceFramework,
     InlineVlmOptions,
     ResponseFormat,
+    TwoStageVlmOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel
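Not part of the patch itself, just a minimal sketch of how a caller is expected to drive the refactored `BaseVlmModel` contract once this lands: prompt construction (`get_user_prompt`) and generation (`predict_on_page_image`) are now separate calls, and `predict_on_page_image` returns a `(text, tokens)` tuple, where the token list carries per-token logprobs when `output_tokens=True`. The `model` and `page` names below are assumed placeholders (any concrete implementation such as `HuggingFaceMlxModel`, and a docling `Page` with a valid backend); the real call sites live in the pipeline and may differ in detail.

```python
import time

from docling.datamodel.base_models import VlmPrediction

# Assumed setup (not in this PR): `model` is a loaded BaseVlmModel implementation
# and `page` is a docling Page with a valid backend and page.size set.
page_image = page.get_image(
    scale=model.vlm_options.scale, max_size=model.vlm_options.max_size
)
assert page_image is not None
if page_image.mode != "RGB":
    page_image = page_image.convert("RGB")

# Stage 1: build the prompt. The MLX implementation applies the processor's chat
# template here; callable prompt options receive page.parsed_page.
prompt = model.get_user_prompt(page)

# Stage 2: run generation. `tokens` holds VlmPredictionToken entries (with logprobs)
# when output_tokens=True.
start_time = time.time()
text, tokens = model.predict_on_page_image(
    page_image=page_image, prompt=prompt, output_tokens=True
)

page.predictions.vlm_response = VlmPrediction(
    text=text, generation_time=time.time() - start_time
)
```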