From 054e01d8b3cf2d67b6ee92c675c15038437d00e6 Mon Sep 17 00:00:00 2001
From: Peter Staar
Date: Wed, 14 May 2025 06:26:16 +0200
Subject: [PATCH] added the formulate_prompt

Signed-off-by: Peter Staar
---
 .../hf_vlm_model_AutoModelForCausalLM.py      | 57 ++++++-------
 ...vlm_model_LlavaForConditionalGeneration.py | 82 +++++++++++--------
 2 files changed, 74 insertions(+), 65 deletions(-)

diff --git a/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py b/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py
index 287da800..53af4b3c 100644
--- a/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py
+++ b/docling/models/hf_vlm_models/hf_vlm_model_AutoModelForCausalLM.py
@@ -44,6 +44,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         self.device = decide_device(accelerator_options.device)
         self.device = "cpu"  # FIXME
 
+        self.use_cache = True
+        self.max_new_tokens = 64  # FIXME
+
         _log.debug(f"Available device for VLM: {self.device}")
 
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@@ -102,29 +105,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(model_path)
 
-    """
-    @staticmethod
-    def download_models(
-        repo_id: str,
-        local_dir: Optional[Path] = None,
-        force: bool = False,
-        progress: bool = False,
-    ) -> Path:
-        from huggingface_hub import snapshot_download
-        from huggingface_hub.utils import disable_progress_bars
-
-        if not progress:
-            disable_progress_bars()
-        download_path = snapshot_download(
-            repo_id=repo_id,
-            force_download=force,
-            local_dir=local_dir,
-            # revision="v0.0.1",
-        )
-
-        return Path(download_path)
-    """
-
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -147,13 +127,8 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                     hi_res_image = hi_res_image.convert("RGB")
 
                 # Define prompt structure
-                user_prompt = "<|user|>"
-                assistant_prompt = "<|assistant|>"
-                prompt_suffix = "<|end|>"
-
-                # Part 1: Image Processing
-                prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
-
+                prompt = self.formulate_prompt()
+
                 inputs = self.processor(
                     text=prompt, images=hi_res_image, return_tensors="pt"
                 ).to(self.device)
@@ -162,7 +137,8 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                 start_time = time.time()
                 generate_ids = self.vlm_model.generate(
                     **inputs,
-                    max_new_tokens=128,
+                    max_new_tokens=self.max_new_tokens,
+                    use_cache=self.use_cache,  # Enables KV caching which can improve performance
                     generation_config=self.generation_config,
                     num_logits_to_keep=1,
                 )
@@ -191,3 +167,22 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                 page.predictions.vlm_response = VlmPrediction(text=response)
 
             yield page
+
+    def formulate_prompt(self) -> str:
+        """Formulate a prompt for the VLM."""
+        if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            user_prompt = "<|user|>"
+            assistant_prompt = "<|assistant|>"
+            prompt_suffix = "<|end|>"
+
+            # prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
+            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+        else:
+            raise ValueError(f"No prompt template for {self.vlm_options.repo_id}")
+
+
+        return ""
+
diff --git a/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py b/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py
index 0f1569ef..e520adf5 100644
--- a/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py
+++ b/docling/models/hf_vlm_models/hf_vlm_model_LlavaForConditionalGeneration.py
@@ -44,14 +44,13 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
         self.device = decide_device(accelerator_options.device)
         self.device = "cpu"  # FIXME
 
-        torch.set_num_threads(12)  # Adjust the number as needed
+        self.use_cache = True
+        self.max_new_tokens = 64  # FIXME
 
         _log.debug(f"Available device for VLM: {self.device}")
 
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-        # PARAMETERS:
         if artifacts_path is None:
-            # artifacts_path = self.download_models(self.vlm_options.repo_id)
             artifacts_path = HuggingFaceVlmModel.download_models(
                 self.vlm_options.repo_id
             )
@@ -59,41 +58,25 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
             artifacts_path = artifacts_path / repo_cache_folder
 
         model_path = artifacts_path
-        print(f"model: {model_path}")
-
-        self.max_new_tokens = 64  # FIXME
+        _log.debug(f"model: {model_path}")
 
         self.processor = AutoProcessor.from_pretrained(
             artifacts_path,
             trust_remote_code=self.trust_remote_code,
         )
         self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
-            artifacts_path
+            artifacts_path,
+            device_map=self.device,
+            # torch_dtype="auto",
+            # quantization_config=self.param_quantization_config,
+            _attn_implementation=(
+                "flash_attention_2"
+                if self.device.startswith("cuda")
+                and accelerator_options.cuda_use_flash_attention2
+                else "eager"
+            ),
         ).to(self.device)
 
-    """
-    @staticmethod
-    def download_models(
-        repo_id: str,
-        local_dir: Optional[Path] = None,
-        force: bool = False,
-        progress: bool = False,
-    ) -> Path:
-        from huggingface_hub import snapshot_download
-        from huggingface_hub.utils import disable_progress_bars
-
-        if not progress:
-            disable_progress_bars()
-        download_path = snapshot_download(
-            repo_id=repo_id,
-            force_download=force,
-            local_dir=local_dir,
-            # revision="v0.0.1",
-        )
-
-        return Path(download_path)
-    """
-
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -116,22 +99,32 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                     hi_res_image = hi_res_image.convert("RGB")
 
                 images = [hi_res_image]
 
-                prompt = "[INST]Describe the images.\n[IMG][/INST]"
+                # Define prompt structure
+                # prompt = "[INST]Describe the images.\n[IMG][/INST]"
+                prompt = self.formulate_prompt()
+
                 inputs = self.processor(
-                    text=prompt, images=images, return_tensors="pt", use_fast=False
+                    text=prompt, images=images, return_tensors="pt"
                 ).to(self.device)  # .to("cuda")
+
+                # Generate response
+                start_time = time.time()
                 generate_ids = self.vlm_model.generate(
                     **inputs,
                     max_new_tokens=self.max_new_tokens,
-                    use_cache=True,  # Enables KV caching which can improve performance
+                    use_cache=self.use_cache,  # Enables KV caching which can improve performance
                 )
+
+                num_tokens = len(generate_ids[0])
+                generation_time = time.time() - start_time
+
                 response = self.processor.batch_decode(
                     generate_ids,
                     skip_special_tokens=True,
                     clean_up_tokenization_spaces=False,
                 )[0]
-                print(f"response: {response}")
+
                 """
                 _log.debug(
                     f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
@@ -147,3 +140,24 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                 page.predictions.vlm_response = VlmPrediction(text=response)
 
             yield page
+
+    def formulate_prompt(self) -> str:
+        """Formulate a prompt for the VLM."""
+        if self.vlm_options.repo_id == "mistral-community/pixtral-12b":
+            # prompt = f"[INST]{self.vlm_options.prompt}\n[IMG][/INST]"
+            chat = [
+                {
+                    "role": "user", "content": [
+                        {"type": "text", "content": self.vlm_options.prompt},
+                        {"type": "image"},
+                    ]
+                }
+            ]
+            prompt = self.processor.apply_chat_template(chat)
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+        else:
+            raise ValueError(f"No prompt template for {self.vlm_options.repo_id}")
+
+        return ""
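
Note for reviewers (illustrative, not part of the patch): a minimal sketch of the prompts the two new formulate_prompt helpers produce. The vlm_options stub below is hypothetical and only mirrors the two fields the patch reads (repo_id, prompt); the prompt text is an assumed example value.

    # Illustrative sketch only; assumes a vlm_options-like object with the
    # two fields read by formulate_prompt in this patch.
    from types import SimpleNamespace

    vlm_options = SimpleNamespace(
        repo_id="microsoft/Phi-4-multimodal-instruct",
        prompt="Convert this page to markdown.",  # hypothetical prompt value
    )

    # Phi-4 path (hf_vlm_model_AutoModelForCausalLM): plain string template.
    phi4_prompt = f"<|user|><|image_1|>{vlm_options.prompt}<|end|><|assistant|>"

    # Pixtral path (hf_vlm_model_LlavaForConditionalGeneration): the chat
    # structure that the patch hands to processor.apply_chat_template(chat).
    pixtral_chat = [
        {
            "role": "user",
            "content": [
                {"type": "text", "content": vlm_options.prompt},
                {"type": "image"},
            ],
        }
    ]

    print(phi4_prompt)
    # <|user|><|image_1|>Convert this page to markdown.<|end|><|assistant|>

Any other repo_id raises ValueError, so both models now fail fast on checkpoints without a known prompt template instead of silently using a mismatched one.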