mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 20:58:11 +00:00
feat: Repetition-based StoppingCriteria for GraniteDocling (#2323)
* Experimental code for repetition detection, VLLM Streaming Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Update VLLM Streaming Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Update VLLM inference code, CLI and VLM specs Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Fix generation and decoder args for HF model Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Fix vllm device args Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Cleanup Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Bugfixes Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Remove streaming VLLM for the moment Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add repetition StoppingCriteria for GraniteDocling/SmolDocling Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Make GenerationStopper base class and port for MLX Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add streaming support and custom GenerationStopper support for ApiVlmModel Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Fixes for ApiVlmModel Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Fixes for ApiVlmModel Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Fix api_image_request_streaming when GenerationStopper triggers. Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Move DocTagsRepetitionStopper to utility unit, update examples Signed-off-by: Christoph Auer <cau@zurich.ibm.com> --------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
@@ -7,6 +8,7 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
|
||||
from docling.models.base_model import BaseVlmPageModel
|
||||
from docling.models.utils.generation_utils import GenerationStopper
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
@@ -69,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
self.vlm_model, self.processor = load(artifacts_path)
|
||||
self.config = load_config(artifacts_path)
|
||||
|
||||
# Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
|
||||
if self.vlm_options.custom_stopping_criteria:
|
||||
for criteria in self.vlm_options.custom_stopping_criteria:
|
||||
if isinstance(criteria, StoppingCriteria):
|
||||
raise ValueError(
|
||||
f"MLX models do not support HuggingFace StoppingCriteria instances. "
|
||||
f"Found {type(criteria).__name__}. Use GenerationStopper instead."
|
||||
)
|
||||
elif isinstance(criteria, type) and issubclass(
|
||||
criteria, StoppingCriteria
|
||||
):
|
||||
raise ValueError(
|
||||
f"MLX models do not support HuggingFace StoppingCriteria classes. "
|
||||
f"Found {criteria.__name__}. Use GenerationStopper instead."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
@@ -193,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
self.processor, self.config, user_prompt, num_images=1
|
||||
)
|
||||
|
||||
# Stream generate with stop strings support
|
||||
# Stream generate with stop strings and custom stopping criteria support
|
||||
start_time = time.time()
|
||||
_log.debug("start generating ...")
|
||||
|
||||
@@ -245,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
_log.debug("Stopping generation due to stop string match")
|
||||
break
|
||||
|
||||
# Check for custom stopping criteria (GenerationStopper instances)
|
||||
if self.vlm_options.custom_stopping_criteria:
|
||||
for criteria in self.vlm_options.custom_stopping_criteria:
|
||||
# Handle both instances and classes of GenerationStopper
|
||||
if isinstance(criteria, GenerationStopper):
|
||||
stopper = criteria
|
||||
elif isinstance(criteria, type) and issubclass(
|
||||
criteria, GenerationStopper
|
||||
):
|
||||
stopper = criteria()
|
||||
|
||||
# Determine the text window to check based on lookback_tokens
|
||||
lookback_tokens = stopper.lookback_tokens()
|
||||
# Check only the last N characters worth of text
|
||||
# This is a simplified approach - in practice, you might want to
|
||||
# decode the last N tokens from the token list for more accuracy
|
||||
text_to_check = (
|
||||
output[-lookback_tokens:]
|
||||
if len(output) > lookback_tokens
|
||||
else output
|
||||
)
|
||||
|
||||
try:
|
||||
if stopper.should_stop(text_to_check):
|
||||
_log.info(
|
||||
f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
_log.warning(
|
||||
f"Error in GenerationStopper.should_stop: {e}"
|
||||
)
|
||||
continue
|
||||
else: # note: for-else idiom
|
||||
continue # Only executed if the inner loop didn't break
|
||||
break # Break the outer loop if any stopper triggered
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
_log.debug(
|
||||
|
||||
Reference in New Issue
Block a user