diff --git a/docling/models/table_structure_model_vlm.py b/docling/models/table_structure_model_vlm.py
new file mode 100644
index 00000000..c6a11035
--- /dev/null
+++ b/docling/models/table_structure_model_vlm.py
@@ -0,0 +1,214 @@
+import logging
+import time
+from collections.abc import Iterable
+from pathlib import Path
+from typing import List, Optional
+
+from docling_core.types.doc import DocItemLabel
+from docling_core.types.doc.utils import parse_otsl_table_content
+from PIL import Image
+
+from docling.datamodel.accelerator_options import AcceleratorOptions
+from docling.datamodel.base_models import (
+    Cluster,
+    Page,
+    Table,
+    TableStructurePrediction,
+    VlmPredictionToken,
+)
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    TableStructureOptions,
+)
+from docling.models.base_model import BasePageModel
+from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class TableStructureModelVlmMlx(BasePageModel, HuggingFaceModelDownloadMixin):
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        options: TableStructureOptions,
+        accelerator_options: AcceleratorOptions,
+    ):
+        self.options = options
+        model_repo_id = "ds4sd/granite-docling-258m-2-9-2025-v2-mlx-bf16"
+
+        self.max_tokens = 4096
+        self.temperature = 0
+        self.stop_strings = ["", ""]
+
+        self.enabled = enabled
+        if self.enabled:
+            try:
+                from mlx_vlm import generate, load, stream_generate  # type: ignore
+                from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
+                from mlx_vlm.utils import load_config  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
+                )
+
+            repo_cache_folder = model_repo_id.replace("/", "--")
+
+            self.apply_chat_template = apply_chat_template
+            self.stream_generate = stream_generate
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(
+                    model_repo_id,
+                )
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            ## Load the model
+            self.vlm_model, self.processor = load(artifacts_path)
+            self.config = load_config(artifacts_path)
+
+        self.scale = 2.0  # Scale up table input images to 144 dpi
+
+    def _predict_images(self, image_batch: Iterable[Image.Image]) -> Iterable[str]:
+        user_prompt = "Convert table to OTSL."
+
+        # Use the MLX chat template approach like in the __call__ method
+        formatted_prompt = self.apply_chat_template(
+            self.processor, self.config, user_prompt, num_images=1
+        )
+
+        for image in image_batch:
+            # Stream generate with stop strings support
+            start_time = time.time()
+            _log.debug("start generating ...")
+
+            tokens: list[VlmPredictionToken] = []
+            output = ""
+
+            # Use stream_generate for proper stop string handling
+            for token in self.stream_generate(
+                self.vlm_model,
+                self.processor,
+                formatted_prompt,
+                [image],  # MLX stream_generate expects a list of images
+                max_tokens=self.max_tokens,
+                verbose=False,
+                temp=self.temperature,
+            ):
+                # Collect token information
+                if len(token.logprobs.shape) == 1:
+                    tokens.append(
+                        VlmPredictionToken(
+                            text=token.text,
+                            token=token.token,
+                            logprob=token.logprobs[token.token],
+                        )
+                    )
+                elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1:
+                    tokens.append(
+                        VlmPredictionToken(
+                            text=token.text,
+                            token=token.token,
+                            logprob=token.logprobs[0, token.token],
+                        )
+                    )
+                else:
+                    _log.warning(
+                        f"incompatible shape for logprobs: {token.logprobs.shape}"
+                    )
+
+                output += token.text
+
+                # Stop early if any configured stop string appears in the output
+                if self.stop_strings:
+                    if any(stop_str in output for stop_str in self.stop_strings):
+                        _log.debug("Stopping generation due to stop string match")
+                        break
+
+            generation_time = time.time() - start_time
+
+            _log.debug(
+                f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
+            )
+
+            yield output
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        # Convert to list to allow multiple iterations
+        pages = list(page_batch)
+
+        # Collect table crops from all valid pages
+        table_images: List[Image.Image] = []
+        table_clusters: List[Cluster] = []
+        pages_to_tables: List[List[int]] = []
+
+        tbl_ix = 0
+        for page in pages:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                pages_to_tables.append([])
+                continue
+
+            table_indexes = []
+            assert page.predictions.layout is not None
+            for cluster in page.predictions.layout.clusters:
+                if cluster.label not in {
+                    DocItemLabel.TABLE,
+                    DocItemLabel.DOCUMENT_INDEX,
+                }:
+                    continue
+
+                table_image = page.get_image(scale=self.scale, cropbox=cluster.bbox)
+                assert table_image is not None
+
+                table_clusters.append(cluster)
+                table_images.append(table_image)
+
+                table_indexes.append(tbl_ix)
+                tbl_ix += 1
+
+            pages_to_tables.append(table_indexes)
+
+        assert len(pages) == len(pages_to_tables)
+
+        # Process all collected table images with batch prediction
+        batch_predictions = []
+        if table_images:
+            with TimeRecorder(conv_res, "table_structure"):
+                batch_predictions = list(self._predict_images(table_images))
+        assert len(batch_predictions) == len(table_images)
+
+        for page, page_tables_map in zip(pages, pages_to_tables):
+            if not page_tables_map:
+                # Invalid pages and pages without tables are passed through unchanged
+                yield page
+                continue
+
+            page.predictions.tablestructure = TableStructurePrediction()  # dummy
+
+            for tbl_ix in page_tables_map:
+                otsl_seq = batch_predictions[tbl_ix]
+                table_cluster = table_clusters[tbl_ix]
+
+                _log.debug(f"{otsl_seq=}")
+                table_data = parse_otsl_table_content(otsl_seq)
+                _log.debug(f"{table_data.num_rows=}")
+                _log.debug(f"{table_data.num_cols=}")
+
+                tbl = Table(
+                    otsl_seq=[otsl_seq],
+                    table_cells=table_data.table_cells,
+                    num_rows=table_data.num_rows,
+                    num_cols=table_data.num_cols,
+                    id=table_cluster.id,
+                    page_no=page.page_no,
+                    cluster=table_cluster,
+                    label=table_cluster.label,
+                )
+
+                page.predictions.tablestructure.table_map[table_cluster.id] = tbl
+
+            yield page
diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py
index c04ddca9..8a58f85b 100644
--- a/docling/pipeline/standard_pdf_pipeline.py
+++ b/docling/pipeline/standard_pdf_pipeline.py
@@ -29,6 +29,7 @@ from docling.models.page_preprocessing_model import (
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
 from docling.models.table_structure_model import TableStructureModel
+from docling.models.table_structure_model_vlm import TableStructureModelVlmMlx
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.model_downloader import download_models
 from docling.utils.profiling import ProfilingScope, TimeRecorder
@@ -81,7 +82,8 @@ class StandardPdfPipeline(PaginatedPipeline):
                 options=pipeline_options.layout_options,
             ),
             # Table structure model
-            TableStructureModel(
+            # TableStructureModel(
+            TableStructureModelVlmMlx(
                 enabled=pipeline_options.do_table_structure,
                 artifacts_path=artifacts_path,
                 options=pipeline_options.table_structure_options,
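
Note: with this change, `TableStructureModelVlmMlx` replaces `TableStructureModel` unconditionally in `StandardPdfPipeline`, so any standard PDF conversion with table structure enabled exercises the MLX path. A minimal sketch of how to try it, assuming this branch is installed together with `mlx-vlm` on Apple silicon; the sample PDF path below is hypothetical:

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

# Table structure is enabled by default; set explicitly here for clarity.
pipeline_options = PdfPipelineOptions()
pipeline_options.do_table_structure = True

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
    }
)

# Hypothetical input path; any PDF containing tables works.
result = converter.convert("tests/data/pdf/sample_with_tables.pdf")
print(result.document.export_to_markdown())
```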