added the formulate_prompt

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-05-14 06:26:16 +02:00
parent 4c0bc61e54
commit 054e01d8b3
2 changed files with 74 additions and 65 deletions


@@ -44,6 +44,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         self.device = decide_device(accelerator_options.device)
         self.device = "cpu"  # FIXME
 
+        self.use_cache = True
+        self.max_new_tokens = 64  # FIXME
+
         _log.debug(f"Available device for VLM: {self.device}")
 
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@@ -102,29 +105,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(model_path)
 
-    """
-    @staticmethod
-    def download_models(
-        repo_id: str,
-        local_dir: Optional[Path] = None,
-        force: bool = False,
-        progress: bool = False,
-    ) -> Path:
-        from huggingface_hub import snapshot_download
-        from huggingface_hub.utils import disable_progress_bars
-
-        if not progress:
-            disable_progress_bars()
-
-        download_path = snapshot_download(
-            repo_id=repo_id,
-            force_download=force,
-            local_dir=local_dir,
-            # revision="v0.0.1",
-        )
-
-        return Path(download_path)
-    """
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
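
Note: the per-class download helper deleted above appears to be replaced by a shared one on the HuggingFaceVlmModel base class (the Llava variant below already calls HuggingFaceVlmModel.download_models). A minimal sketch of such a helper, assuming it keeps the same huggingface_hub-based signature as the removed code:

from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars


def download_models(
    repo_id: str,
    local_dir: Optional[Path] = None,
    force: bool = False,
    progress: bool = False,
) -> Path:
    # Silence the per-file progress bars unless explicitly requested
    if not progress:
        disable_progress_bars()
    download_path = snapshot_download(
        repo_id=repo_id,
        force_download=force,
        local_dir=local_dir,
    )
    return Path(download_path)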
@@ -147,13 +127,8 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                hi_res_image = hi_res_image.convert("RGB")
 
            # Define prompt structure
-           user_prompt = "<|user|>"
-           assistant_prompt = "<|assistant|>"
-           prompt_suffix = "<|end|>"
-
-           # Part 1: Image Processing
-           prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
+           prompt = self.formulate_prompt()
 
            inputs = self.processor(
                text=prompt, images=hi_res_image, return_tensors="pt"
            ).to(self.device)
@@ -162,7 +137,8 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
            start_time = time.time()
            generate_ids = self.vlm_model.generate(
                **inputs,
-               max_new_tokens=128,
+               max_new_tokens=self.max_new_tokens,
+               use_cache=self.use_cache,  # Enables KV caching which can improve performance
                generation_config=self.generation_config,
                num_logits_to_keep=1,
            )
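
For context on the new use_cache / max_new_tokens fields: both models time the generate() call and log a token count (see the Llava hunk further down). A rough standalone sketch of that measurement, with vlm_model and inputs assumed to be an already-loaded model and processed batch; note that len(generate_ids[0]) counts the full returned sequence, prompt tokens included:

import time

start_time = time.time()
generate_ids = vlm_model.generate(
    **inputs,
    max_new_tokens=64,  # mirrors self.max_new_tokens
    use_cache=True,     # mirrors self.use_cache (KV caching)
)
generation_time = time.time() - start_time
num_tokens = len(generate_ids[0])  # full sequence length, prompt included
print(
    f"Generated {num_tokens} tokens in {generation_time:.2f} seconds "
    f"({num_tokens / generation_time:.1f} tok/s)"
)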
@@ -191,3 +167,22 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
            page.predictions.vlm_response = VlmPrediction(text=response)
 
            yield page
+
+    def formulate_prompt(self) -> str:
+        """Formulate a prompt for the VLM."""
+
+        if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            user_prompt = "<|user|>"
+            assistant_prompt = "<|assistant|>"
+            prompt_suffix = "<|end|>"
+
+            # prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
+            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+        else:
+            raise ValueError(f"No prompt template for {self.vlm_options.repo_id}")
+
+        return ""


@@ -44,14 +44,13 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
         self.device = decide_device(accelerator_options.device)
         self.device = "cpu"  # FIXME
 
-        torch.set_num_threads(12)  # Adjust the number as needed
+        self.use_cache = True
+        self.max_new_tokens = 64  # FIXME
 
         _log.debug(f"Available device for VLM: {self.device}")
 
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-        # PARAMETERS:
         if artifacts_path is None:
-            # artifacts_path = self.download_models(self.vlm_options.repo_id)
             artifacts_path = HuggingFaceVlmModel.download_models(
                 self.vlm_options.repo_id
             )
@@ -59,41 +58,25 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
            artifacts_path = artifacts_path / repo_cache_folder
 
        model_path = artifacts_path
-       print(f"model: {model_path}")
-
-       self.max_new_tokens = 64  # FIXME
+       _log.debug(f"model: {model_path}")
 
        self.processor = AutoProcessor.from_pretrained(
            artifacts_path,
            trust_remote_code=self.trust_remote_code,
        )
        self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
-           artifacts_path
+           artifacts_path,
+           device_map=self.device,
+           # torch_dtype="auto",
+           # quantization_config=self.param_quantization_config,
+           _attn_implementation=(
+               "flash_attention_2"
+               if self.device.startswith("cuda")
+               and accelerator_options.cuda_use_flash_attention2
+               else "eager"
+           ),
        ).to(self.device)
 
-   """
-   @staticmethod
-   def download_models(
-       repo_id: str,
-       local_dir: Optional[Path] = None,
-       force: bool = False,
-       progress: bool = False,
-   ) -> Path:
-       from huggingface_hub import snapshot_download
-       from huggingface_hub.utils import disable_progress_bars
-
-       if not progress:
-           disable_progress_bars()
-
-       download_path = snapshot_download(
-           repo_id=repo_id,
-           force_download=force,
-           local_dir=local_dir,
-           # revision="v0.0.1",
-       )
-
-       return Path(download_path)
-   """
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
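
The conditional _attn_implementation argument added above can be read as a small predicate; a sketch, assuming device is a string such as "cuda:0" or "cpu" and cuda_use_flash_attention2 is the corresponding accelerator option:

def pick_attn_implementation(device: str, cuda_use_flash_attention2: bool) -> str:
    # flash_attention_2 is only usable on CUDA devices and only when explicitly enabled
    if device.startswith("cuda") and cuda_use_flash_attention2:
        return "flash_attention_2"
    return "eager"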
@@ -116,22 +99,32 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                hi_res_image = hi_res_image.convert("RGB")
 
            images = [hi_res_image]
-           prompt = "<s>[INST]Describe the images.\n[IMG][/INST]"
+
+           # Define prompt structure
+           # prompt = "<s>[INST]Describe the images.\n[IMG][/INST]"
+           prompt = self.formulate_prompt()
 
            inputs = self.processor(
-               text=prompt, images=images, return_tensors="pt", use_fast=False
+               text=prompt, images=images, return_tensors="pt"
            ).to(self.device)  # .to("cuda")
 
+           # Generate response
+           start_time = time.time()
            generate_ids = self.vlm_model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
-               use_cache=True,  # Enables KV caching which can improve performance
+               use_cache=self.use_cache,  # Enables KV caching which can improve performance
            )
+           num_tokens = len(generate_ids[0])
+           generation_time = time.time() - start_time
+
            response = self.processor.batch_decode(
                generate_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
-           print(f"response: {response}")
+
            """
            _log.debug(
                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
@@ -147,3 +140,24 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
            page.predictions.vlm_response = VlmPrediction(text=response)
 
            yield page
+
+    def formulate_prompt(self) -> str:
+        """Formulate a prompt for the VLM."""
+
+        if self.vlm_options.repo_id == "mistral-community/pixtral-12b":
+            # prompt = f"<s>[INST]{self.vlm_options.prompt}\n[IMG][/INST]"
+
+            chat = [
+                {
+                    "role": "user", "content": [
+                        {"type": "text", "content": self.vlm_options.prompt},
+                        {"type": "image"},
+                    ]
+                }
+            ]
+            prompt = self.processor.apply_chat_template(chat)
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+        else:
+            raise ValueError(f"No prompt template for {self.vlm_options.repo_id}")
+
+        return ""