streamlining all code

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-05-16 16:27:27 +02:00
parent 661f7c9780
commit d5b6c871cf
7 changed files with 55 additions and 36 deletions

View File

@@ -84,6 +84,7 @@ smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
     response_format=ResponseFormat.DOCTAGS,
     inference_framework=InferenceFramework.MLX,
     scale=2.0,
+    temperature=0.0,
 )
 
 smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
@@ -92,6 +93,7 @@ smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
     response_format=ResponseFormat.DOCTAGS,
     inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
     scale=2.0,
+    temperature=0.0,
 )
 
 # GraniteVision
@@ -101,6 +103,7 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
     scale=2.0,
+    temperature=0.0,
 )
 
 granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
@@ -110,6 +113,7 @@ granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
     scale=1.0,
     timeout=120,
     response_format=ResponseFormat.MARKDOWN,
+    temperature=0.0,
 )
 
 # Pixtral
@@ -119,6 +123,7 @@ pixtral_12b_vlm_conversion_options = HuggingFaceVlmOptions(
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.TRANSFORMERS_LlavaForConditionalGeneration,
     scale=2.0,
+    temperature=0.0,
 )
 
 pixtral_12b_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
@@ -127,6 +132,7 @@ pixtral_12b_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.MLX,
     scale=2.0,
+    temperature=0.0,
 )
 
 # Phi4
@@ -135,6 +141,8 @@ phi_vlm_conversion_options = HuggingFaceVlmOptions(
     prompt="Convert this page to MarkDown. Do not miss any text and only output the bare MarkDown",
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForCausalLM,
+    scale=2.0,
+    temperature=0.0,
 )
 
 # Qwen
@@ -143,4 +151,6 @@ qwen25_vl_3b_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
     prompt="Convert this page to markdown. Do not miss any text and only output the bare MarkDown!",
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.MLX,
+    scale=2.0,
+    temperature=0.0,
 )
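Every preset above now pins temperature=0.0, so decoding is effectively deterministic, and the Phi4 and Qwen presets additionally gain an explicit scale=2.0. A minimal sketch of a user-defined preset following the same convention; the repo_id, prompt, and import path below are placeholders and assumptions, not part of this commit:

from docling.datamodel.pipeline_options import (  # import path assumed; may differ in this revision
    HuggingFaceVlmOptions,
    InferenceFramework,
    ResponseFormat,
)

# Hypothetical preset mirroring the fields touched in the hunks above
my_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="example-org/example-vlm",  # placeholder model id
    prompt="Convert this page to markdown.",
    response_format=ResponseFormat.MARKDOWN,
    inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
    scale=2.0,
    temperature=0.0,  # new field: greedy, reproducible generation
)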

View File

@@ -6,6 +6,17 @@ _log = logging.getLogger(__name__)
 class HuggingFaceVlmModel:
+
+    @staticmethod
+    def map_device_to_cpu_if_mlx(device: str) -> str:
+        if device == "mps":
+            _log.warning(
+                "Mapping mlx to cpu for AutoModelForCausalLM, use MLX framework!"
+            )
+            return "cpu"
+        return device
+
     @staticmethod
     def download_models(
         repo_id: str,
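The new static helper centralizes the device fallback that the model constructors below previously open-coded: when the resolved device string is "mps" it logs a warning and returns "cpu", otherwise it returns the device unchanged. A minimal usage sketch with illustrative values:

device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx("mps")   # returns "cpu" and logs a warning
device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx("cuda")  # returns "cuda" unchanged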

View File

@@ -29,7 +29,8 @@ class HuggingFaceMlxModel(BasePageModel):
         self.vlm_options = vlm_options
         self.max_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature
 
         if self.enabled:
             try:
                 from mlx_vlm import generate, load  # type: ignore
@@ -103,8 +104,9 @@ class HuggingFaceMlxModel(BasePageModel):
                     self.processor,
                     prompt,
                     [hi_res_image],
-                    max_tokens=4096,
+                    max_tokens=self.max_tokens,
                     verbose=False,
+                    temp=self.temperature,
                 ):
                     if len(token.logprobs.shape) == 1:
                         tokens.append(

View File

@@ -42,19 +42,13 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
             )
 
             self.device = decide_device(accelerator_options.device)
-            if self.device == "mps":
-                _log.warning(
-                    "Mapping mlx to cpu for AutoModelForCausalLM, use MLX framework!"
-                )
-                self.device = "cpu"
-
-            print("device: ", self.device)
+            self.device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx(self.device)
+            _log.debug(f"Available device for VLM: {self.device}")
 
             self.use_cache = vlm_options.use_kv_cache
             self.max_new_tokens = vlm_options.max_new_tokens
-
-            _log.debug(f"Available device for VLM: {self.device}")
+            self.temperature = vlm_options.temperature
 
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
             if artifacts_path is None:
@@ -126,12 +120,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                 if hi_res_image is not None:
                     im_width, im_height = hi_res_image.size
 
-                """
-                if hi_res_image:
-                    if hi_res_image.mode != "RGB":
-                        hi_res_image = hi_res_image.convert("RGB")
-                """
-
                 # Define prompt structure
                 prompt = self.formulate_prompt()
@@ -147,9 +135,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                     **inputs,
                     max_new_tokens=self.max_new_tokens,
                     use_cache=self.use_cache,  # Enables KV caching which can improve performance
+                    temperature=self.temperature,
                     generation_config=self.generation_config,
                     num_logits_to_keep=1,
-                    # temperature=0.0,
                 )
 
                 generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
@@ -162,8 +150,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                     clean_up_tokenization_spaces=False,
                 )[0]
 
-                #_log.debug(
-                print(
+                _log.debug(
                     f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                 )
 
                 page.predictions.vlm_response = VlmPrediction(text=response, generation_time=generation_time)

View File

@@ -39,8 +39,14 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
             )
 
             self.device = decide_device(accelerator_options.device)
+            self.device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx(self.device)
+
             _log.debug(f"Available device for HuggingFace VLM: {self.device}")
 
+            self.use_cache = vlm_options.use_kv_cache
+            self.max_new_tokens = vlm_options.max_new_tokens
+            self.temperature = vlm_options.temperature
+
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
             # PARAMETERS:
@@ -111,10 +117,12 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                 # populate page_tags with predicted doc tags
                 page_tags = ""
 
+                """
                 if hi_res_image:
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")
+                """
 
                 # Define prompt structure
                 prompt = self.formulate_prompt()
@@ -126,7 +134,10 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                 start_time = time.time()
                 # Call model to generate:
                 generated_ids = self.vlm_model.generate(
-                    **inputs, max_new_tokens=4096, use_cache=True
+                    **inputs,
+                    max_new_tokens=self.max_new_tokens,
+                    use_cache=self.use_cache,
+                    temperature=self.temperature,
                 )
 
                 generation_time = time.time() - start_time

View File

@@ -39,16 +39,12 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
             )
 
             self.device = decide_device(accelerator_options.device)
-
-            if self.device == "mlx":
-                _log.warning(
-                    "Mapping mlx to cpu for LlavaForConditionalGeneration, use MLX framework!"
-                )
-                self.device = "cpu"
+            self.device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx(self.device)
 
             self.use_cache = vlm_options.use_kv_cache
             self.max_new_tokens = vlm_options.max_new_tokens
+            self.temperature = vlm_options.temperature
 
             _log.debug(f"Available device for VLM: {self.device}")
 
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@@ -93,10 +89,12 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                 if hi_res_image is not None:
                     im_width, im_height = hi_res_image.size
 
+                """
                 if hi_res_image:
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")
+                """
 
                 images = [hi_res_image]
 
                 # Define prompt structure
@@ -112,9 +110,10 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                     **inputs,
                     max_new_tokens=self.max_new_tokens,
                     use_cache=self.use_cache,  # Enables KV caching which can improve performance
+                    temperature=self.temperature,
                 )
 
-                num_tokens = len(generate_ids[0])
+                #num_tokens = len(generate_ids[0])
                 generation_time = time.time() - start_time
 
                 response = self.processor.batch_decode(
@@ -125,7 +124,7 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                 page.predictions.vlm_response = VlmPrediction(
                     text=response,
-                    generated_tokens=num_tokens,
+                    #generated_tokens=num_tokens,
                     generation_time=generation_time,
                 )
@@ -134,7 +133,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""
         if self.vlm_options.repo_id == "mistral-community/pixtral-12b":
-            # prompt = f"<s>[INST]{self.vlm_options.prompt}\n[IMG][/INST]"
             chat = [
                 {
                     "role": "user",

View File

@@ -187,9 +187,9 @@ if __name__ == "__main__":
     rows = []
     for vlm_options in [
         # smoldocling_vlm_conversion_options, \
-        # smoldocling_vlm_mlx_conversion_options, \
+        smoldocling_vlm_mlx_conversion_options, \
         # granite_vision_vlm_conversion_options, \
-        phi_vlm_conversion_options, \
+        # phi_vlm_conversion_options, \
         # qwen25_vl_3b_vlm_mlx_conversion_options, \
         # pixtral_12b_vlm_mlx_conversion_options,
         # pixtral_12b_vlm_conversion_options,