pixtral 12b runs via MLX and native transformers

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-05-14 07:39:20 +02:00
parent 054e01d8b3
commit f159075b67
2 changed files with 14 additions and 4 deletions

View File

@@ -40,22 +40,29 @@ class HuggingFaceMlxModel(BasePageModel):
         )
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        print(f"model init: {repo_cache_folder}")
         self.apply_chat_template = apply_chat_template
         self.stream_generate = stream_generate
         # PARAMETERS:
         if artifacts_path is None:
+            print(f"before HuggingFaceVlmModel.download_models: {self.vlm_options.repo_id}")
             # artifacts_path = self.download_models(self.vlm_options.repo_id)
             artifacts_path = HuggingFaceVlmModel.download_models(
-                self.vlm_options.repo_id
+                self.vlm_options.repo_id, progress=True,
             )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder
+        print(f"downloaded model: {artifacts_path}")
         self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
         ## Load the model
+        print("start loading model ...")
         self.vlm_model, self.processor = load(artifacts_path)
+        print("loaded model ...")
         self.config = load_config(artifacts_path)
         """
@@ -110,6 +117,8 @@ class HuggingFaceMlxModel(BasePageModel):
                 )
                 start_time = time.time()
+                print("start generating ...")
                 # Call model to generate:
                 output = ""
                 for token in self.stream_generate(
@@ -120,6 +129,7 @@ class HuggingFaceMlxModel(BasePageModel):
                     max_tokens=4096,
                     verbose=False,
                 ):
+                    print(token.text, end="", flush=True)
                     output += token.text
                     if "</doctag>" in token.text:
                         break

View File

@@ -49,6 +49,7 @@ pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
 vlm_conversion_options = pixtral_vlm_conversion_options
 """
+"""
 pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
     repo_id="mistral-community/pixtral-12b",
     prompt="OCR this image and export it in MarkDown.",
@@ -56,6 +57,7 @@ pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
     inference_framework=InferenceFramework.TRANSFORMERS_LlavaForConditionalGeneration,
 )
 vlm_conversion_options = pixtral_vlm_conversion_options
+"""
 """
 phi_vlm_conversion_options = HuggingFaceVlmOptions(
@@ -68,15 +70,13 @@ phi_vlm_conversion_options = HuggingFaceVlmOptions(
 vlm_conversion_options = phi_vlm_conversion_options
 """
-"""
 pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
     repo_id="mlx-community/pixtral-12b-bf16",
-    prompt="Convert this full page to markdown. Do not miss any text and only output the bare MarkDown!",
+    prompt="Convert this page to markdown. Do not miss any text and only output the bare MarkDown!",
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.MLX,
 )
 vlm_conversion_options = pixtral_vlm_conversion_options
-"""
 """
 qwen_vlm_conversion_options = HuggingFaceVlmOptions(