feat: Repetition-based StoppingCriteria for GraniteDocling (#2323)

* Experimental code for repetition detection, VLLM Streaming

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update VLLM Streaming

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update VLLM inference code, CLI and VLM specs

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix generation and decoder args for HF model

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix vllm device args

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Cleanup

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Bugfixes

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Remove streaming VLLM for the moment

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add repetition StoppingCriteria for GraniteDocling/SmolDocling

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Make GenerationStopper base class and port for MLX

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add streaming support and custom GenerationStopper support for ApiVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fixes for ApiVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fixes for ApiVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix api_image_request_streaming when GenerationStopper triggers.

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Move DocTagsRepetitionStopper to utility unit, update examples

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer
2025-09-30 15:26:09 +02:00
committed by GitHub
parent 68ae7ccf3c
commit 1e9dc43b72
15 changed files with 541 additions and 38 deletions

View File

@@ -1,4 +1,5 @@
import logging
import sys
import threading
import time
from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union
import numpy as np
from PIL.Image import Image
from transformers import StoppingCriteria
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.generation_utils import GenerationStopper
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@@ -69,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
# Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
if self.vlm_options.custom_stopping_criteria:
for criteria in self.vlm_options.custom_stopping_criteria:
if isinstance(criteria, StoppingCriteria):
raise ValueError(
f"MLX models do not support HuggingFace StoppingCriteria instances. "
f"Found {type(criteria).__name__}. Use GenerationStopper instead."
)
elif isinstance(criteria, type) and issubclass(
criteria, StoppingCriteria
):
raise ValueError(
f"MLX models do not support HuggingFace StoppingCriteria classes. "
f"Found {criteria.__name__}. Use GenerationStopper instead."
)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@@ -193,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
self.processor, self.config, user_prompt, num_images=1
)
# Stream generate with stop strings support
# Stream generate with stop strings and custom stopping criteria support
start_time = time.time()
_log.debug("start generating ...")
@@ -245,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
_log.debug("Stopping generation due to stop string match")
break
# Check for custom stopping criteria (GenerationStopper instances)
if self.vlm_options.custom_stopping_criteria:
for criteria in self.vlm_options.custom_stopping_criteria:
# Handle both instances and classes of GenerationStopper
if isinstance(criteria, GenerationStopper):
stopper = criteria
elif isinstance(criteria, type) and issubclass(
criteria, GenerationStopper
):
stopper = criteria()
# Determine the text window to check based on lookback_tokens
lookback_tokens = stopper.lookback_tokens()
# Check only the last N characters worth of text
# This is a simplified approach - in practice, you might want to
# decode the last N tokens from the token list for more accuracy
text_to_check = (
output[-lookback_tokens:]
if len(output) > lookback_tokens
else output
)
try:
if stopper.should_stop(text_to_check):
_log.info(
f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
)
break
except Exception as e:
_log.warning(
f"Error in GenerationStopper.should_stop: {e}"
)
continue
else: # note: for-else idiom
continue # Only executed if the inner loop didn't break
break # Break the outer loop if any stopper triggered
generation_time = time.time() - start_time
_log.debug(