feat: Repetition-based StoppingCriteria for GraniteDocling (#2323)

* Experimental code for repetition detection, VLLM Streaming

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update VLLM Streaming

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update VLLM inference code, CLI and VLM specs

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix generation and decoder args for HF model

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix vllm device args

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Cleanup

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Bugfixes

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Remove streaming VLLM for the moment

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add repetition StoppingCriteria for GraniteDocling/SmolDocling

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Make GenerationStopper base class and port for MLX

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add streaming support and custom GenerationStopper support for ApiVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fixes for ApiVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fixes for ApiVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix api_image_request_streaming when GenerationStopper triggers.

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Move DocTagsRepetitionStopper to utility unit, update examples

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Author: Christoph Auer
Date: 2025-09-30 15:26:09 +02:00 (committed by GitHub)
Parent: 68ae7ccf3c
Commit: 1e9dc43b72
15 changed files with 541 additions and 38 deletions
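
For context, the new GenerationStopper interface (added below in docling/models/utils/generation_utils.py) only requires should_stop() and, optionally, lookback_tokens(). A minimal sketch of a user-defined stopper; the class name and threshold are illustrative and not part of this commit:

from docling.models.utils.generation_utils import GenerationStopper


class MaxCharsStopper(GenerationStopper):
    """Hypothetical stopper: abort once the decoded text window exceeds a character budget."""

    def __init__(self, max_chars: int = 50_000):
        self.max_chars = max_chars

    def should_stop(self, s: str) -> bool:
        # 's' is the decoded tail of the generation, bounded by lookback_tokens().
        return len(s) >= self.max_chars

    def lookback_tokens(self) -> int:
        # Only the last ~60k tokens need to be decoded for this check.
        return 60_000

Instances (or classes) of such stoppers can be placed in custom_stopping_criteria on both InlineVlmOptions and ApiVlmOptions, as the diffs below show.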

View File

@@ -60,7 +60,7 @@ jobs:
        run: |
          for file in docs/examples/*.py; do
            # Skip batch_convert.py
-           if [[ "$(basename "$file")" =~ ^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
+           if [[ "$(basename "$file")" =~ ^(batch_convert|granitedocling_repetition_stopping|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
              echo "Skipping $file"
              continue
            fi

View File

@@ -78,7 +78,7 @@ class AsciiDocBackend(DeclarativeDocumentBackend):
         return doc
-    def _parse(self, doc: DoclingDocument):  # noqa: C901
+    def _parse(self, doc: DoclingDocument):
         """
         Main function that orchestrates the parsing by yielding components:
         title, section headers, text, lists, and tables.

View File

@@ -812,7 +812,7 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
                 else prev_parent
             )
-    def _handle_text_elements(  # noqa: C901
+    def _handle_text_elements(
         self,
         element: BaseOxmlElement,
         docx_obj: DocxDocument,

View File

@@ -352,7 +352,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         return
-    def _parse_element_citation(self, node: etree._Element) -> str:  # noqa: C901
+    def _parse_element_citation(self, node: etree._Element) -> str:
         citation: Citation = {
             "author_names": "",
             "title": "",
@@ -538,7 +538,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         return
     @staticmethod
-    def parse_table_data(element: Tag) -> Optional[TableData]:  # noqa: C901
+    def parse_table_data(element: Tag) -> Optional[TableData]:
         # TODO, see how to implement proper support for rich tables from HTML backend
         nested_tables = element.find("table")
         if nested_tables is not None:
@@ -713,7 +713,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         )
         return
-    def _walk_linear(  # noqa: C901
+    def _walk_linear(
         self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
     ) -> str:
         skip_tags = ["term"]

View File

@@ -1523,7 +1523,7 @@ class XmlTable:
         return ncols_max
-    def _parse_table(self, table: Tag) -> TableData:  # noqa: C901
+    def _parse_table(self, table: Tag) -> TableData:
         """Parse the content of a table tag.
         Args:

View File

@@ -1,11 +1,13 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Union
 from docling_core.types.doc.page import SegmentedPage
-from pydantic import AnyUrl, BaseModel
+from pydantic import AnyUrl, BaseModel, ConfigDict
+from transformers import StoppingCriteria
 from typing_extensions import deprecated
 from docling.datamodel.accelerator_options import AcceleratorDevice
+from docling.models.utils.generation_utils import GenerationStopper
 class BaseVlmOptions(BaseModel):
@@ -50,6 +52,8 @@ class TransformersPromptStyle(str, Enum):
 class InlineVlmOptions(BaseVlmOptions):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     kind: Literal["inline_model_options"] = "inline_model_options"
     repo_id: str
@@ -72,6 +76,7 @@ class InlineVlmOptions(BaseVlmOptions):
     ]
     stop_strings: List[str] = []
+    custom_stopping_criteria: List[Union[StoppingCriteria, GenerationStopper]] = []
     extra_generation_config: Dict[str, Any] = {}
     extra_processor_kwargs: Dict[str, Any] = {}
@@ -89,6 +94,8 @@ class HuggingFaceVlmOptions(InlineVlmOptions):
 class ApiVlmOptions(BaseVlmOptions):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     kind: Literal["api_model_options"] = "api_model_options"
     url: AnyUrl = AnyUrl(
@@ -99,3 +106,6 @@ class ApiVlmOptions(BaseVlmOptions):
     timeout: float = 60
     concurrency: int = 1
     response_format: ResponseFormat
+    stop_strings: List[str] = []
+    custom_stopping_criteria: List[Union[GenerationStopper]] = []
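
A short sketch of how the new ApiVlmOptions fields could be configured against a remote OpenAI-compatible endpoint. The URL and parameter values are illustrative, and the ResponseFormat import from this module is an assumption; the same pattern appears in the commented-out section of the example script further below.

from docling.datamodel import vlm_model_specs
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat
from docling.models.utils.generation_utils import DocTagsRepetitionStopper

# The stopper is evaluated against the streamed response as chunks arrive.
api_options = ApiVlmOptions(
    url="http://localhost:8000/v1/chat/completions",  # e.g. a local vLLM server
    params={
        "model": vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.repo_id,
        "max_tokens": 8192,
    },
    prompt=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.prompt,
    response_format=ResponseFormat.DOCTAGS,
    timeout=90,
    custom_stopping_criteria=[DocTagsRepetitionStopper(N=1)],
)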

View File

@@ -1,12 +1,18 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
+from transformers import StoppingCriteria
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
 from docling.models.base_model import BasePageModel
-from docling.utils.api_image_request import api_image_request
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.utils.api_image_request import (
+    api_image_request,
+    api_image_request_streaming,
+)
 from docling.utils.profiling import TimeRecorder
@@ -41,19 +47,43 @@ class ApiVlmModel(BasePageModel):
             assert page._backend is not None
             if not page._backend.is_valid():
                 return page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-                    hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
-                    )
-                    assert hi_res_image is not None
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-                    prompt = self.vlm_options.build_prompt(page.parsed_page)
+            with TimeRecorder(conv_res, "vlm"):
+                assert page.size is not None
+                hi_res_image = page.get_image(
+                    scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                )
+                assert hi_res_image is not None
+                if hi_res_image and hi_res_image.mode != "RGB":
+                    hi_res_image = hi_res_image.convert("RGB")
+                prompt = self.vlm_options.build_prompt(page.parsed_page)
+                if self.vlm_options.custom_stopping_criteria:
+                    # Instantiate any GenerationStopper classes before passing to streaming
+                    instantiated_stoppers = []
+                    for criteria in self.vlm_options.custom_stopping_criteria:
+                        if isinstance(criteria, GenerationStopper):
+                            instantiated_stoppers.append(criteria)
+                        elif isinstance(criteria, type) and issubclass(
+                            criteria, GenerationStopper
+                        ):
+                            instantiated_stoppers.append(criteria())
+                        # Skip non-GenerationStopper criteria (should have been caught in validation)
+                    # Streaming path with early abort support
+                    page_tags = api_image_request_streaming(
+                        image=hi_res_image,
+                        prompt=prompt,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        generation_stoppers=instantiated_stoppers,
+                        **self.params,
+                    )
+                else:
+                    # Non-streaming fallback (existing behavior)
                     page_tags = api_image_request(
                         image=hi_res_image,
                         prompt=prompt,
@@ -63,10 +93,9 @@ class ApiVlmModel(BasePageModel):
                         **self.params,
                     )
                 page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(text=page_tags)
-                    return page
+            return page
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_vlm_request, page_batch)

View File

@@ -103,7 +103,7 @@ class ReadingOrderModel:
         else:
             doc.add_text(parent=doc_item, label=c_label, text=c_text, prov=c_prov)
-    def _readingorder_elements_to_docling_doc(  # noqa: C901
+    def _readingorder_elements_to_docling_doc(
         self,
         conv_res: ConversionResult,
         ro_elements: List[ReadingOrderPageElement],

View File

@@ -0,0 +1,157 @@
import logging
import re
import sys
from abc import abstractmethod
from typing import List
from transformers import StoppingCriteria
_log = logging.getLogger(__name__)
class GenerationStopper:
"""
Base interface for stopping logic.
- should_stop(s): True to stop given the current decoded text window.
- lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
"""
@abstractmethod
def should_stop(self, s: str) -> bool:
pass
def lookback_tokens(self) -> int:
return sys.maxsize
class DocTagsRepetitionStopper(GenerationStopper):
"""
Detects repetitive <tag>...<loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
but only when repeats are **consecutive** and both tag & inner text are identical.
Performance:
- Heavy check runs every N calls (default 32).
- Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
"""
def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
self.N = max(1, int(N))
self._lookback_tokens = max(1, int(lookback_tokens))
self._call_count = 0
# <tag> ... <loc_x><loc_y><loc_w><loc_h> text ... </tag>
self._PATTERN = re.compile(
r"""
<(?P<tag>[a-zA-Z0-9_]+)>\s*
(?P<prefix>.*?)?
<loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
(?P<text>.*?)
</(?P=tag)>
""",
re.DOTALL | re.VERBOSE,
)
# --- small helper ---
def _regular(self, vals: List[int]) -> bool:
"""3+ strictly increasing values with ~regular spacing (±20%)."""
if len(vals) < 3:
return False
diffs = [b - a for a, b in zip(vals, vals[1:])]
if any(d <= 0 for d in diffs):
return False
mean = sum(diffs) / len(diffs)
tol = 0.2 * mean
return all(abs(d - mean) <= tol for d in diffs)
def should_stop(self, s: str) -> bool:
"""
Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
with the same <tag> and identical inner text, where within that run we see:
- any exact duplicate (x,y,w,h), or
- stable X/W with regular Y progression, or
- stable Y/H with regular X progression.
"""
# Stream matches and evaluate runs on-the-fly to stay compact and fast.
prev_tag = prev_text = None
run = [] # list of (x,y,w,h)
def run_repetitive(boxes: List[tuple]) -> bool:
if len(boxes) < 3:
return False
# duplicates?
if len(set(boxes)) < len(boxes):
return True
xs, ys, ws, hs = zip(*boxes)
x_stable = all(x == xs[0] for x in xs)
y_stable = all(y == ys[0] for y in ys)
w_stable = all(w == ws[0] for w in ws)
h_stable = all(h == hs[0] for h in hs)
# horizontal (down the page): X/W stable, Y regular
if (x_stable or w_stable) and self._regular(list(ys)):
return True
# vertical (across): Y/H stable, X regular
if (y_stable or h_stable) and self._regular(list(xs)):
return True
return False
for m in self._PATTERN.finditer(s):
tag, text = m.group("tag"), m.group("text")
box = (
int(m.group("x")),
int(m.group("y")),
int(m.group("w")),
int(m.group("h")),
)
if prev_tag == tag and prev_text == text:
run.append(box) # consecutive same-tag+text
else:
# evaluate previous run before starting a new one
if run_repetitive(run):
return True
prev_tag, prev_text = tag, text
run = [box]
# check the last run
return run_repetitive(run)
class HFStoppingCriteriaWrapper(StoppingCriteria):
"""
Adapts any GenerationStopper to HuggingFace Transformers.
Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
"""
def __init__(
self,
tokenizer,
stopper: GenerationStopper,
*,
skip_special_tokens: bool = False,
):
self.tokenizer = tokenizer
self.stopper = stopper
self.skip_special_tokens = skip_special_tokens
def __call__(self, input_ids, scores, **kwargs) -> bool:
lb = max(1, int(self.stopper.lookback_tokens()))
for seq in input_ids: # (batch, seq_len)
window = seq[-lb:] # slicing handles lb > len(seq)
try:
text = self.tokenizer.decode(
window, skip_special_tokens=self.skip_special_tokens
)
except Exception as e:
_log.info(f"Decoding failed for stopping check: {e}")
continue
try:
if self.stopper.should_stop(text):
_log.info(
"HF wrapper: stopping due to TextStopper.should_stop==True"
)
return True
except Exception as e:
_log.info(f"Error in TextStopper.should_stop: {e}")
continue
return False
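
A quick sanity-check sketch of the stopper above on synthetic DocTags output; the strings and expected results are illustrative, not part of the commit:

from docling.models.utils.generation_utils import DocTagsRepetitionStopper

stopper = DocTagsRepetitionStopper(N=1)
block = "<text><loc_10><loc_20><loc_100><loc_12>Same line over and over</text>"

# Four consecutive blocks with identical tag, text, and box -> repetition detected.
print(stopper.should_stop(block * 4))  # expected: True
# A single block is not a repetition.
print(stopper.should_stop(block))      # expected: False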

View File

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -253,17 +257,50 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
         # -- Optional stopping criteria
-        stopping_criteria = None
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-            stopping_criteria = StoppingCriteriaList(
-                [
-                    StopStringCriteria(
-                        stop_strings=self.vlm_options.stop_strings,
-                        tokenizer=self.processor.tokenizer,
-                    )
-                ]
+            stopping_criteria_list.append(
+                StopStringCriteria(
+                    stop_strings=self.vlm_options.stop_strings,
+                    tokenizer=self.processor.tokenizer,
+                )
             )
+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class, instantiate with tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )
         # -- Filter out decoder-specific keys from extra_generation_config
         decoder_keys = {
             "skip_special_tokens",

View File

@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union
 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -69,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
+            # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    if isinstance(criteria, StoppingCriteria):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                            f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                        )
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, StoppingCriteria
+                    ):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                            f"Found {criteria.__name__}. Use GenerationStopper instead."
+                        )
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -193,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.processor, self.config, user_prompt, num_images=1
         )
-        # Stream generate with stop strings support
+        # Stream generate with stop strings and custom stopping criteria support
         start_time = time.time()
         _log.debug("start generating ...")
@@ -245,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     _log.debug("Stopping generation due to stop string match")
                     break
+            # Check for custom stopping criteria (GenerationStopper instances)
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    # Handle both instances and classes of GenerationStopper
+                    if isinstance(criteria, GenerationStopper):
+                        stopper = criteria
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, GenerationStopper
+                    ):
+                        stopper = criteria()
+                    # Determine the text window to check based on lookback_tokens
+                    lookback_tokens = stopper.lookback_tokens()
+                    # Check only the last N characters worth of text
+                    # This is a simplified approach - in practice, you might want to
+                    # decode the last N tokens from the token list for more accuracy
+                    text_to_check = (
+                        output[-lookback_tokens:]
+                        if len(output) > lookback_tokens
+                        else output
+                    )
+                    try:
+                        if stopper.should_stop(text_to_check):
+                            _log.info(
+                                f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                            )
+                            break
+                    except Exception as e:
+                        _log.warning(
+                            f"Error in GenerationStopper.should_stop: {e}"
+                        )
+                        continue
+                else:  # note: for-else idiom
+                    continue  # Only executed if the inner loop didn't break
+                break  # Break the outer loop if any stopper triggered
         generation_time = time.time() - start_time
         _log.debug(

View File

@@ -1,13 +1,15 @@
 import base64
+import json
 import logging
 from io import BytesIO
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 import requests
 from PIL import Image
 from pydantic import AnyUrl
 from docling.datamodel.base_models import OpenAiApiResponse
+from docling.models.utils.generation_utils import GenerationStopper
 _log = logging.getLogger(__name__)
@@ -59,3 +61,107 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     return generated_text
def api_image_request_streaming(
image: Image.Image,
prompt: str,
url: AnyUrl,
*,
timeout: float = 20,
headers: Optional[Dict[str, str]] = None,
generation_stoppers: List[GenerationStopper] = [],
**params,
) -> str:
"""
Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
Accumulates text and calls stopper.should_stop(window) as chunks arrive.
If stopper triggers, the HTTP connection is closed to abort server-side generation.
"""
img_io = BytesIO()
image.save(img_io, "PNG")
image_b64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
},
{"type": "text", "text": prompt},
],
}
]
payload = {
"messages": messages,
"stream": True, # <-- critical for SSE streaming
**params,
}
# Debug: Log the payload to verify temperature is included
_log.debug(f"API streaming request payload: {json.dumps(payload, indent=2)}")
# Some servers require Accept: text/event-stream for SSE.
# It's safe to set it; OpenAI-compatible servers tolerate it.
hdrs = {"Accept": "text/event-stream", **(headers or {})}
# Try to force temperature via header if server ignores payload parameter
if "temperature" in params:
hdrs["X-Temperature"] = str(params["temperature"])
# Stream the HTTP response
with requests.post(
str(url), headers=hdrs, json=payload, timeout=timeout, stream=True
) as r:
if not r.ok:
_log.error(
f"Error calling the API {url} in streaming mode. Response was {r.text}"
)
r.raise_for_status()
full_text = []
for raw_line in r.iter_lines(decode_unicode=True):
if not raw_line: # keep-alives / blank lines
continue
if not raw_line.startswith("data:"):
# Some proxies inject comments; ignore anything not starting with 'data:'
continue
data = raw_line[len("data:") :].strip()
if data == "[DONE]":
break
try:
obj = json.loads(data)
except json.JSONDecodeError:
_log.debug("Skipping non-JSON SSE chunk: %r", data[:200])
continue
# OpenAI-compatible delta format
# obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls)
try:
delta = obj["choices"][0].get("delta") or {}
piece = delta.get("content") or ""
except (KeyError, IndexError) as e:
_log.debug("Unexpected SSE chunk shape: %s", e)
piece = ""
if piece:
full_text.append(piece)
for stopper in generation_stoppers:
# Respect stopper's lookback window. We use a simple string window which
# works with the GenerationStopper interface.
lookback = max(1, stopper.lookback_tokens())
window = "".join(full_text)[-lookback:]
if stopper.should_stop(window):
# Break out of the loop cleanly. The context manager will handle
# closing the connection when we exit the 'with' block.
# vLLM/OpenAI-compatible servers will detect the client disconnect
# and abort the request server-side.
return "".join(full_text)
return "".join(full_text)

View File

@@ -0,0 +1,108 @@
# %% [markdown]
# Experimental VLM pipeline with custom repetition stopping criteria.
#
# This script demonstrates the use of custom stopping criteria that detect
# repetitive location coordinate patterns in generated text and stop generation
# when such patterns are found.
#
# What this example does
# - Uses the GraniteDocling model with custom repetition stopping criteria injected
# - Processes a PDF document or image and monitors for repetitive coordinate patterns
# - Stops generation early when repetitive patterns are detected
# %%
import logging
from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.models.utils.generation_utils import (
DocTagsRepetitionStopper,
)
from docling.pipeline.vlm_pipeline import VlmPipeline
logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
# Set up logging to see when repetition stopping is triggered
logging.basicConfig(level=logging.INFO)
# Replace with a local path if preferred.
# source = "https://ibm.biz/docling-page-with-table" # Example that shows no repetitions.
source = "tests/data_scanned/old_newspaper.png" # Example that creates repetitions.
print(f"Processing document: {source}")
###### USING GRANITEDOCLING WITH CUSTOM REPETITION STOPPING
## Using standard Huggingface Transformers (most portable, slowest)
custom_vlm_options = vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.model_copy()
# Uncomment this to use MLX-accelerated version on Apple Silicon
# custom_vlm_options = vlm_model_specs.GRANITEDOCLING_MLX.model_copy() # use this for Apple Silicon
# Create custom VLM options with repetition stopping criteria
custom_vlm_options.custom_stopping_criteria = [
DocTagsRepetitionStopper(N=32)
] # check for repetitions for every 32 new tokens decoded.
pipeline_options = VlmPipelineOptions(
vlm_options=custom_vlm_options,
)
converter = DocumentConverter(
format_options={
InputFormat.IMAGE: PdfFormatOption(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
}
)
doc = converter.convert(source=source).document
print(doc.export_to_markdown())
## Using a remote VLM inference service (for example VLLM) - uncomment to use
# custom_vlm_options = ApiVlmOptions(
# url="http://localhost:8000/v1/chat/completions", # LM studio defaults to port 1234, VLLM to 8000
# params=dict(
# model=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.repo_id,
# max_tokens=8192,
# skip_special_tokens=True, # needed for VLLM
# ),
# headers={
# "Authorization": "Bearer YOUR_API_KEY",
# },
# prompt=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.prompt,
# timeout=90,
# scale=2.0,
# temperature=0.0,
# response_format=ResponseFormat.DOCTAGS,
# custom_stopping_criteria=[
# DocTagsRepetitionStopper(N=1)
# ], # check for repetitions for every new chunk of the response stream
# )
# pipeline_options = VlmPipelineOptions(
# vlm_options=custom_vlm_options,
# enable_remote_services=True, # required when using a remote inference service.
# )
# converter = DocumentConverter(
# format_options={
# InputFormat.IMAGE: PdfFormatOption(
# pipeline_cls=VlmPipeline,
# pipeline_options=pipeline_options,
# ),
# }
# )
# doc = converter.convert(source=source).document
# print(doc.export_to_markdown())

View File

@@ -217,7 +217,7 @@ classmethod-decorators = [
"tests/*.py" = ["ASYNC"] # Disable ASYNC check for tests "tests/*.py" = ["ASYNC"] # Disable ASYNC check for tests
[tool.ruff.lint.mccabe] [tool.ruff.lint.mccabe]
max-complexity = 20 max-complexity = 30
# [tool.ruff.lint.isort.sections] # [tool.ruff.lint.isort.sections]
# "docling" = ["docling_core", "docling_ibm_models", "docling_parse"] # "docling" = ["docling_core", "docling_ibm_models", "docling_parse"]

BIN  tests/data_scanned/old_newspaper.png (new binary file, 4.0 MiB, not shown)