diff --git a/docling/models/hf_mlx_model.py b/docling/models/hf_mlx_model.py
new file mode 100644
index 00000000..1c4577ac
--- /dev/null
+++ b/docling/models/hf_mlx_model.py
@@ -0,0 +1,132 @@
+import logging
+import time
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    AcceleratorDevice,
+    AcceleratorOptions,
+    HuggingFaceVlmOptions,
+)
+from docling.datamodel.settings import settings
+from docling.models.base_model import BasePageModel
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class HuggingFaceMlxModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: HuggingFaceVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+            from mlx_vlm import generate, load  # type: ignore
+            from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
+            from mlx_vlm.utils import load_config, stream_generate  # type: ignore
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+            self.apply_chat_template = apply_chat_template
+            self.stream_generate = stream_generate
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
+
+            ## Load the model
+            self.vlm_model, self.processor = load(artifacts_path)
+            self.config = load_config(artifacts_path)
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+            # revision="v0.0.1",
+        )
+
+        return Path(download_path)
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
+                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
+
+                    if hi_res_image is not None:
+                        im_width, im_height = hi_res_image.size
+
+                    # populate page_tags with predicted doc tags
+                    page_tags = ""
+
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    prompt = self.apply_chat_template(
+                        self.processor, self.config, self.param_question, num_images=1
+                    )
+
+                    start_time = time.time()
+                    # Call model to generate:
+                    output = ""
+                    for token in self.stream_generate(
+                        self.vlm_model,
+                        self.processor,
+                        prompt,
+                        [hi_res_image],
+                        max_tokens=4096,
+                        verbose=False,
+                    ):
+                        output += token.text
+                        print(token.text, end="")
+                        if "</doctag>" in token.text:
+                            break
+
+                    generation_time = time.time() - start_time
+                    page_tags = output
+
+                    # inference_time = time.time() - start_time
+                    # tokens_per_second = num_tokens / generation_time
+                    # print("")
+                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
+                    # print(f"Total tokens on page: {num_tokens:.2f}")
+                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
+                    # print("")
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                yield page
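
Usage note (not part of the diff): a minimal sketch of pre-fetching the MLX weights with the static download_models helper added above, so that a later run can pass the returned directory as artifacts_path. The repo id and target directory below are assumptions for illustration, not values taken from this change.

    from pathlib import Path

    from docling.models.hf_mlx_model import HuggingFaceMlxModel

    # Download the checkpoint once, outside the conversion run.
    # "ds4sd/SmolDocling-256M-preview-mlx-bf16" is an assumed repo id for illustration.
    weights_dir = HuggingFaceMlxModel.download_models(
        repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
        local_dir=Path("./model_artifacts"),
        progress=True,
    )
    # Hand weights_dir to HuggingFaceMlxModel(..., artifacts_path=weights_dir, ...)
    # to skip the implicit snapshot_download in __init__.
    print(weights_dir)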