From 4c0bc61e54cb71a20f48cb01cc6c827e6ea15a50 Mon Sep 17 00:00:00 2001
From: Peter Staar
Date: Wed, 14 May 2025 05:31:54 +0200
Subject: [PATCH] refactoring the download_model

Signed-off-by: Peter Staar
---
 docling/datamodel/pipeline_options.py         |   4 +-
 docling/models/hf_vlm_model.py                | 153 +-----------------
 .../models/hf_vlm_models/hf_vlm_mlx_model.py  |   8 +-
 .../hf_vlm_model_AutoModelForCausalLM.py      |   9 +-
 .../hf_vlm_model_AutoModelForVision2Seq.py    |   8 +-
 ...vlm_model_LlavaForConditionalGeneration.py |  46 +++---
 .../models/hf_vlm_models/pixtral_12b_2409.py  |  33 ----
 docling/pipeline/vlm_pipeline.py              |  16 +-
 docs/examples/minimal_vlm_pipeline.py         |   8 +-
 9 files changed, 64 insertions(+), 221 deletions(-)
 delete mode 100644 docling/models/hf_vlm_models/pixtral_12b_2409.py

diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index f6d127ca..edbf7b58 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -269,7 +269,9 @@ class InferenceFramework(str, Enum):
     OPENAI = "openai"
     TRANSFORMERS_AutoModelForVision2Seq = "transformers-AutoModelForVision2Seq"
     TRANSFORMERS_AutoModelForCausalLM = "transformers-AutoModelForCausalLM"
-    TRANSFORMERS_LlavaForConditionalGeneration = "transformers-LlavaForConditionalGeneration"
+    TRANSFORMERS_LlavaForConditionalGeneration = (
+        "transformers-LlavaForConditionalGeneration"
+    )
 
 
 class HuggingFaceVlmOptions(BaseVlmOptions):
diff --git a/docling/models/hf_vlm_model.py b/docling/models/hf_vlm_model.py
index 79518f0f..9ae45b2f 100644
--- a/docling/models/hf_vlm_model.py
+++ b/docling/models/hf_vlm_model.py
@@ -17,81 +17,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)
 
 
-class HuggingFaceVlmModel(BasePageModel):
-    """
-    def __init__(
-        self,
-        enabled: bool,
-        artifacts_path: Optional[Path],
-        accelerator_options: AcceleratorOptions,
-        vlm_options: HuggingFaceVlmOptions,
-    ):
-        self.enabled = enabled
-
-        self.vlm_options = vlm_options
-
-        if self.enabled:
-            import torch
-            from transformers import (  # type: ignore
-                AutoModelForVision2Seq,
-                AutoProcessor,
-                BitsAndBytesConfig,
-            )
-
-            device = decide_device(accelerator_options.device)
-            self.device = device
-
-            _log.debug(f"Available device for HuggingFace VLM: {device}")
-
-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-
-            # PARAMETERS:
-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
-            self.param_quantization_config = BitsAndBytesConfig(
-                load_in_8bit=vlm_options.load_in_8bit,  # True,
-                llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
-            )
-            self.param_quantized = vlm_options.quantized  # False
-
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                # trust_remote_code=True,
-            )
-            if not self.param_quantized:
-                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
-                    artifacts_path,
-                    device_map=self.device,
-                    torch_dtype=torch.bfloat16,
-                    _attn_implementation=(
-                        "flash_attention_2"
-                        if self.device.startswith("cuda")
-                        and accelerator_options.cuda_use_flash_attention2
-                        else "eager"
-                    ),
-                    # trust_remote_code=True,
-                )  # .to(self.device)
-
-            else:
-                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
-                    artifacts_path,
-                    device_map=self.device,
-                    torch_dtype="auto",
-                    quantization_config=self.param_quantization_config,
-                    _attn_implementation=(
-                        "flash_attention_2"
-                        if self.device.startswith("cuda")
-                        and accelerator_options.cuda_use_flash_attention2
-                        else "eager"
-                    ),
-                    # trust_remote_code=True,
-                )  # .to(self.device)
-    """
-
+class HuggingFaceVlmModel:
     @staticmethod
     def download_models(
         repo_id: str,
@@ -112,80 +38,3 @@ class HuggingFaceVlmModel(BasePageModel):
         )
 
         return Path(download_path)
-
-    """
-    def __call__(
-        self, conv_res: ConversionResult, page_batch: Iterable[Page]
-    ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-
-                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
-                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
-
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-
-                    messages = [
-                        {
-                            "role": "user",
-                            "content": [
-                                {
-                                    "type": "text",
-                                    "text": "This is a page from a document.",
-                                },
-                                {"type": "image"},
-                                {"type": "text", "text": self.param_question},
-                            ],
-                        }
-                    ]
-                    prompt = self.processor.apply_chat_template(
-                        messages, add_generation_prompt=False
-                    )
-                    inputs = self.processor(
-                        text=prompt, images=[hi_res_image], return_tensors="pt"
-                    )
-                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                    start_time = time.time()
-                    # Call model to generate:
-                    generated_ids = self.vlm_model.generate(
-                        **inputs, max_new_tokens=4096, use_cache=True
-                    )
-
-                    generation_time = time.time() - start_time
-                    generated_texts = self.processor.batch_decode(
-                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
-                    )[0]
-
-                    num_tokens = len(generated_ids[0])
-                    page_tags = generated_texts
-
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
-
-                yield page
-    """
diff --git a/docling/models/hf_vlm_models/hf_vlm_mlx_model.py b/docling/models/hf_vlm_models/hf_vlm_mlx_model.py
index 63f8fc95..4dc90bf7 100644
--- a/docling/models/hf_vlm_models/hf_vlm_mlx_model.py
+++ b/docling/models/hf_vlm_models/hf_vlm_mlx_model.py
@@ -11,6 +11,7 @@ from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
 )
 from docling.models.base_model import BasePageModel
+from docling.models.hf_vlm_model import HuggingFaceVlmModel
 from docling.utils.profiling import TimeRecorder
 
 _log = logging.getLogger(__name__)
@@ -44,7 +45,10 @@ class HuggingFaceMlxModel(BasePageModel):
 
             # PARAMETERS:
             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                # artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = HuggingFaceVlmModel.download_models(
+                    self.vlm_options.repo_id
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
@@ -54,6 +58,7 @@ class HuggingFaceMlxModel(BasePageModel):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
 
+    """
     @staticmethod
     def download_models(
         repo_id: str,
@@ -74,6 +79,7 @@ class HuggingFaceMlxModel(BasePageModel):
         )
 
         return Path(download_path)
+    """
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
diff --git a/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py b/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py
index 5cfe2006..287da800 100644
--- a/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py
+++ b/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py
@@ -11,6 +11,7 @@ from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
 )
 from docling.models.base_model import BasePageModel
+from docling.models.hf_vlm_model import HuggingFaceVlmModel
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder
 
@@ -30,7 +31,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
 
         self.trust_remote_code = True
         self.vlm_options = vlm_options
-        print(self.vlm_options)
 
         if self.enabled:
             import torch
@@ -49,7 +49,10 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
 
             # PARAMETERS:
             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                # artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = HuggingFaceVlmModel.download_models(
+                    self.vlm_options.repo_id
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
@@ -99,6 +102,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
             # Load generation config
             self.generation_config = GenerationConfig.from_pretrained(model_path)
 
+    """
     @staticmethod
     def download_models(
         repo_id: str,
@@ -119,6 +123,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         )
 
         return Path(download_path)
+    """
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
diff --git a/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForVision2Seq.py b/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForVision2Seq.py
index 9d25e7c5..b4313a5f 100644
--- a/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForVision2Seq.py
+++ b/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForVision2Seq.py
@@ -11,6 +11,7 @@ from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
 )
 from docling.models.base_model import BasePageModel
+from docling.models.hf_vlm_model import HuggingFaceVlmModel
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder
 
@@ -46,7 +47,10 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
 
             # PARAMETERS:
             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                # artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = HuggingFaceVlmModel.download_models(
+                    self.vlm_options.repo_id
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
@@ -90,6 +94,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                     # trust_remote_code=True,
                 )  # .to(self.device)
 
+    """
     @staticmethod
     def download_models(
         repo_id: str,
@@ -110,6 +115,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
         )
 
         return Path(download_path)
+    """
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
diff --git a/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py b/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py
index 897ec9f6..0f1569ef 100644
--- a/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py
+++ b/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py
@@ -4,6 +4,8 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
@@ -11,11 +13,10 @@ from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
 )
 from docling.models.base_model import BasePageModel
+from docling.models.hf_vlm_model import HuggingFaceVlmModel
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder
 
-from transformers import AutoProcessor, LlavaForConditionalGeneration
-
 _log = logging.getLogger(__name__)
 
 
@@ -32,41 +33,45 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
 
         self.trust_remote_code = True
         self.vlm_options = vlm_options
-        print(self.vlm_options)
 
         if self.enabled:
             import torch
             from transformers import (  # type: ignore
-                LlavaForConditionalGeneration,
                 AutoProcessor,
+                LlavaForConditionalGeneration,
             )
 
             self.device = decide_device(accelerator_options.device)
             self.device = "cpu"  # FIXME
             torch.set_num_threads(12)  # Adjust the number as needed
-
+
             _log.debug(f"Available device for VLM: {self.device}")
 
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
             # PARAMETERS:
             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                # artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = HuggingFaceVlmModel.download_models(
+                    self.vlm_options.repo_id
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
             model_path = artifacts_path
 
             print(f"model: {model_path}")
-            self.max_new_tokens = 64 # FIXME
-
+            self.max_new_tokens = 64  # FIXME
+
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=self.trust_remote_code,
             )
-            self.vlm_model = LlavaForConditionalGeneration.from_pretrained(artifacts_path).to(self.device)
+            self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
+                artifacts_path
+            ).to(self.device)
-
+    """
 
     @staticmethod
     def download_models(
         repo_id: str,
@@ -87,6 +92,7 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
         )
 
         return Path(download_path)
+    """
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -109,20 +115,22 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")
 
-                    images = [
-                        hi_res_image
-                    ]
+                    images = [hi_res_image]
                     prompt = "[INST]Describe the images.\n[IMG][/INST]"
-
-                    inputs = self.processor(text=prompt, images=images, return_tensors="pt", use_fast=False).to(self.device) #.to("cuda")
+
+                    inputs = self.processor(
+                        text=prompt, images=images, return_tensors="pt", use_fast=False
+                    ).to(self.device)  # .to("cuda")
                     generate_ids = self.vlm_model.generate(
                         **inputs,
                         max_new_tokens=self.max_new_tokens,
-                        use_cache=True  # Enables KV caching which can improve performance
+                        use_cache=True,  # Enables KV caching which can improve performance
                     )
-                    response = self.processor.batch_decode(generate_ids,
-                        skip_special_tokens=True,
-                        clean_up_tokenization_spaces=False)[0]
+                    response = self.processor.batch_decode(
+                        generate_ids,
+                        skip_special_tokens=True,
+                        clean_up_tokenization_spaces=False,
+                    )[0]
                     print(f"response: {response}")
     """
                     _log.debug(
diff --git a/docling/models/hf_vlm_models/pixtral_12b_2409.py b/docling/models/hf_vlm_models/pixtral_12b_2409.py
deleted file mode 100644
index a9afa225..00000000
--- a/docling/models/hf_vlm_models/pixtral_12b_2409.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import logging
-import time
-from collections.abc import Iterable
-from pathlib import Path
-from typing import Optional
-
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options import (
-    AcceleratorOptions,
-    HuggingFaceVlmOptions,
-)
-from docling.models.base_model import BasePageModel
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
-_log = logging.getLogger(__name__)
-
-
-class HuggingFaceVlmModel_pixtral_12b_2409(BasePageModel):
-    def __init__(
-        self,
-        enabled: bool,
-        artifacts_path: Optional[Path],
-        accelerator_options: AcceleratorOptions,
-        vlm_options: HuggingFaceVlmOptions,
-    ):
-        self.enabled = enabled
-
-        self.vlm_options = vlm_options
-
-        if self.enabled:
-            import torch
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index 5123fc38..b90a17dd 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -24,18 +24,16 @@ from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel
 
 # from docling.models.hf_vlm_model import HuggingFaceVlmModel
-from docling.models.hf_vlm_models.hf_vlm_mlx_model import (
-    HuggingFaceMlxModel
-)
-from docling.models.hf_vlm_models.hf_vlm_model_LlavaForConditionalGeneration import (
-    HuggingFaceVlmModel_LlavaForConditionalGeneration
-)
+from docling.models.hf_vlm_models.hf_vlm_mlx_model import HuggingFaceMlxModel
 from docling.models.hf_vlm_models.hf_vlm_model_AutoModelForCausalLM import (
     HuggingFaceVlmModel_AutoModelForCausalLM,
 )
 from docling.models.hf_vlm_models.hf_vlm_model_AutoModelForVision2Seq import (
     HuggingFaceVlmModel_AutoModelForVision2Seq,
 )
+from docling.models.hf_vlm_models.hf_vlm_model_LlavaForConditionalGeneration import (
+    HuggingFaceVlmModel_LlavaForConditionalGeneration,
+)
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder
 
@@ -124,9 +122,11 @@ class VlmPipeline(PaginatedPipeline):
                     accelerator_options=pipeline_options.accelerator_options,
                     vlm_options=vlm_options,
                 ),
-            ]
+            ]
         else:
-            raise ValueError(f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}")
+            raise ValueError(
+                f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
+            )
 
         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
diff --git a/docs/examples/minimal_vlm_pipeline.py b/docs/examples/minimal_vlm_pipeline.py
index eebf6699..5ab971c3 100644
--- a/docs/examples/minimal_vlm_pipeline.py
+++ b/docs/examples/minimal_vlm_pipeline.py
@@ -50,10 +50,10 @@ vlm_conversion_options = pixtral_vlm_conversion_options
 """
 
 pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
-    repo_id="mistral-community/pixtral-12b",
-    prompt="OCR this image and export it in MarkDown.",
-    response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_LlavaForConditionalGeneration,
+    repo_id="mistral-community/pixtral-12b",
+    prompt="OCR this image and export it in MarkDown.",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.TRANSFORMERS_LlavaForConditionalGeneration,
 )
 
 vlm_conversion_options = pixtral_vlm_conversion_options
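Usage sketch (not part of the patch above): after this refactor every Hugging Face VLM wrapper resolves its weights through the single HuggingFaceVlmModel.download_models() static helper instead of keeping a per-class copy. The snippet below only illustrates that shared call pattern; the helper name resolve_artifacts_path is hypothetical, since the wrappers inline this logic directly in their __init__.

    from pathlib import Path
    from typing import Optional

    from docling.models.hf_vlm_model import HuggingFaceVlmModel


    def resolve_artifacts_path(repo_id: str, artifacts_path: Optional[Path]) -> Path:
        # Cache folder name mirrors the repo id, e.g. "org/model" -> "org--model".
        repo_cache_folder = repo_id.replace("/", "--")
        if artifacts_path is None:
            # No local artifacts given: fetch via the shared static download helper.
            return HuggingFaceVlmModel.download_models(repo_id)
        if (artifacts_path / repo_cache_folder).exists():
            # Reuse a previously downloaded snapshot under the artifacts path.
            return artifacts_path / repo_cache_folder
        return artifacts_path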