Pixtral 12B runs via MLX and native transformers

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-05-14 07:39:20 +02:00
parent 054e01d8b3
commit f159075b67
2 changed files with 14 additions and 4 deletions

View File

@ -40,22 +40,29 @@ class HuggingFaceMlxModel(BasePageModel):
)
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
print(f"model init: {repo_cache_folder}")
self.apply_chat_template = apply_chat_template
self.stream_generate = stream_generate
# PARAMETERS:
if artifacts_path is None:
print(f"before HuggingFaceVlmModel.download_models: {self.vlm_options.repo_id}")
# artifacts_path = self.download_models(self.vlm_options.repo_id)
artifacts_path = HuggingFaceVlmModel.download_models(
self.vlm_options.repo_id
self.vlm_options.repo_id, progress=True,
)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
print(f"downloaded model: {artifacts_path}")
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
## Load the model
print("start loading model ...")
self.vlm_model, self.processor = load(artifacts_path)
print("loaded model ...")
self.config = load_config(artifacts_path)
"""
@ -110,6 +117,8 @@ class HuggingFaceMlxModel(BasePageModel):
)
start_time = time.time()
print("start generating ...")
# Call model to generate:
output = ""
for token in self.stream_generate(
@ -120,6 +129,7 @@ class HuggingFaceMlxModel(BasePageModel):
max_tokens=4096,
verbose=False,
):
print(token.text, end="", flush=True)
output += token.text
if "</doctag>" in token.text:
break

View File

@ -49,6 +49,7 @@ pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
vlm_conversion_options = pixtral_vlm_conversion_options
"""
"""
pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="mistral-community/pixtral-12b",
prompt="OCR this image and export it in MarkDown.",
@ -56,6 +57,7 @@ pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
inference_framework=InferenceFramework.TRANSFORMERS_LlavaForConditionalGeneration,
)
vlm_conversion_options = pixtral_vlm_conversion_options
"""
"""
phi_vlm_conversion_options = HuggingFaceVlmOptions(
@ -68,15 +70,13 @@ phi_vlm_conversion_options = HuggingFaceVlmOptions(
vlm_conversion_options = phi_vlm_conversion_options
"""
"""
pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="mlx-community/pixtral-12b-bf16",
prompt="Convert this full page to markdown. Do not miss any text and only output the bare MarkDown!",
prompt="Convert this page to markdown. Do not miss any text and only output the bare MarkDown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.MLX,
)
vlm_conversion_options = pixtral_vlm_conversion_options
"""
"""
qwen_vlm_conversion_options = HuggingFaceVlmOptions(