feat: Repetition-based StoppingCriteria for GraniteDocling (#2323)
* Experimental code for repetition detection, VLLM Streaming
* Update VLLM Streaming
* Update VLLM inference code, CLI and VLM specs
* Fix generation and decoder args for HF model
* Fix vllm device args
* Cleanup
* Bugfixes
* Remove streaming VLLM for the moment
* Add repetition StoppingCriteria for GraniteDocling/SmolDocling
* Make GenerationStopper base class and port for MLX
* Add streaming support and custom GenerationStopper support for ApiVlmModel
* Fixes for ApiVlmModel
* Fix api_image_request_streaming when GenerationStopper triggers
* Move DocTagsRepetitionStopper to utility unit, update examples

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
.github/workflows/checks.yml (2 changed lines)
@@ -60,7 +60,7 @@ jobs:
       run: |
         for file in docs/examples/*.py; do
           # Skip batch_convert.py
-          if [[ "$(basename "$file")" =~ ^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
+          if [[ "$(basename "$file")" =~ ^(batch_convert|granitedocling_repetition_stopping|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
             echo "Skipping $file"
             continue
           fi

@@ -78,7 +78,7 @@ class AsciiDocBackend(DeclarativeDocumentBackend):

         return doc

-    def _parse(self, doc: DoclingDocument):  # noqa: C901
+    def _parse(self, doc: DoclingDocument):
         """
         Main function that orchestrates the parsing by yielding components:
         title, section headers, text, lists, and tables.

@@ -812,7 +812,7 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
             else prev_parent
         )

-    def _handle_text_elements(  # noqa: C901
+    def _handle_text_elements(
         self,
         element: BaseOxmlElement,
         docx_obj: DocxDocument,

@@ -352,7 +352,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):

         return

-    def _parse_element_citation(self, node: etree._Element) -> str:  # noqa: C901
+    def _parse_element_citation(self, node: etree._Element) -> str:
         citation: Citation = {
             "author_names": "",
             "title": "",
@@ -538,7 +538,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         return

     @staticmethod
-    def parse_table_data(element: Tag) -> Optional[TableData]:  # noqa: C901
+    def parse_table_data(element: Tag) -> Optional[TableData]:
         # TODO, see how to implement proper support for rich tables from HTML backend
         nested_tables = element.find("table")
         if nested_tables is not None:
@@ -713,7 +713,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         )
         return

-    def _walk_linear(  # noqa: C901
+    def _walk_linear(
         self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
     ) -> str:
         skip_tags = ["term"]

@@ -1523,7 +1523,7 @@ class XmlTable:

         return ncols_max

-    def _parse_table(self, table: Tag) -> TableData:  # noqa: C901
+    def _parse_table(self, table: Tag) -> TableData:
         """Parse the content of a table tag.

         Args:

@@ -1,11 +1,13 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Union

 from docling_core.types.doc.page import SegmentedPage
-from pydantic import AnyUrl, BaseModel
+from pydantic import AnyUrl, BaseModel, ConfigDict
+from transformers import StoppingCriteria
 from typing_extensions import deprecated

 from docling.datamodel.accelerator_options import AcceleratorDevice
+from docling.models.utils.generation_utils import GenerationStopper


 class BaseVlmOptions(BaseModel):
@@ -50,6 +52,8 @@ class TransformersPromptStyle(str, Enum):


 class InlineVlmOptions(BaseVlmOptions):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     kind: Literal["inline_model_options"] = "inline_model_options"

     repo_id: str
@@ -72,6 +76,7 @@ class InlineVlmOptions(BaseVlmOptions):
     ]

     stop_strings: List[str] = []
+    custom_stopping_criteria: List[Union[StoppingCriteria, GenerationStopper]] = []
     extra_generation_config: Dict[str, Any] = {}
     extra_processor_kwargs: Dict[str, Any] = {}

@@ -89,6 +94,8 @@ class HuggingFaceVlmOptions(InlineVlmOptions):


 class ApiVlmOptions(BaseVlmOptions):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     kind: Literal["api_model_options"] = "api_model_options"

     url: AnyUrl = AnyUrl(
@@ -99,3 +106,6 @@ class ApiVlmOptions(BaseVlmOptions):
     timeout: float = 60
     concurrency: int = 1
     response_format: ResponseFormat
+
+    stop_strings: List[str] = []
+    custom_stopping_criteria: List[Union[GenerationStopper]] = []

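For context: the new custom_stopping_criteria fields accept both GenerationStopper instances and classes (classes are instantiated by the backends). A minimal sketch of a user-defined stopper; the BudgetStopper name and size limit below are hypothetical and not part of this commit:

from docling.models.utils.generation_utils import GenerationStopper


class BudgetStopper(GenerationStopper):
    """Hypothetical example: stop once the decoded window exceeds a size budget."""

    def __init__(self, max_chars: int = 50_000):
        self.max_chars = max_chars

    def should_stop(self, s: str) -> bool:
        # 's' is the decoded text window that the backend hands over.
        return len(s) >= self.max_chars

    def lookback_tokens(self) -> int:
        return 4096  # only the tail of the sequence needs to be decoded


# Attach to any VLM options object, e.g.:
# options = vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.model_copy()
# options.custom_stopping_criteria = [BudgetStopper(max_chars=20_000)]
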
@@ -1,12 +1,18 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor

+from transformers import StoppingCriteria
+
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
 from docling.models.base_model import BasePageModel
-from docling.utils.api_image_request import api_image_request
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.utils.api_image_request import (
+    api_image_request,
+    api_image_request_streaming,
+)
 from docling.utils.profiling import TimeRecorder

@@ -41,7 +47,7 @@ class ApiVlmModel(BasePageModel):
             assert page._backend is not None
             if not page._backend.is_valid():
                 return page
             else:

                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None

@@ -49,11 +55,35 @@ class ApiVlmModel(BasePageModel):
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
-                    assert hi_res_image is not None
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
+                    if hi_res_image and hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")

                     prompt = self.vlm_options.build_prompt(page.parsed_page)
+
+                    if self.vlm_options.custom_stopping_criteria:
+                        # Instantiate any GenerationStopper classes before passing to streaming
+                        instantiated_stoppers = []
+                        for criteria in self.vlm_options.custom_stopping_criteria:
+                            if isinstance(criteria, GenerationStopper):
+                                instantiated_stoppers.append(criteria)
+                            elif isinstance(criteria, type) and issubclass(
+                                criteria, GenerationStopper
+                            ):
+                                instantiated_stoppers.append(criteria())
+                            # Skip non-GenerationStopper criteria (should have been caught in validation)
+
+                        # Streaming path with early abort support
+                        page_tags = api_image_request_streaming(
+                            image=hi_res_image,
+                            prompt=prompt,
+                            url=self.vlm_options.url,
+                            timeout=self.timeout,
+                            headers=self.vlm_options.headers,
+                            generation_stoppers=instantiated_stoppers,
+                            **self.params,
+                        )
+                    else:
+                        # Non-streaming fallback (existing behavior)
+                        page_tags = api_image_request(
+                            image=hi_res_image,
+                            prompt=prompt,

@@ -65,7 +95,6 @@ class ApiVlmModel(BasePageModel):

                     page_tags = self.vlm_options.decode_response(page_tags)
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)
-
                     return page

         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:

@@ -103,7 +103,7 @@ class ReadingOrderModel:
             else:
                 doc.add_text(parent=doc_item, label=c_label, text=c_text, prov=c_prov)

-    def _readingorder_elements_to_docling_doc(  # noqa: C901
+    def _readingorder_elements_to_docling_doc(
         self,
         conv_res: ConversionResult,
         ro_elements: List[ReadingOrderPageElement],

docling/models/utils/generation_utils.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import logging
import re
import sys
from abc import abstractmethod
from typing import List

from transformers import StoppingCriteria

_log = logging.getLogger(__name__)


class GenerationStopper:
    """
    Base interface for stopping logic.
    - should_stop(s): True to stop given the current decoded text window.
    - lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
    """

    @abstractmethod
    def should_stop(self, s: str) -> bool:
        pass

    def lookback_tokens(self) -> int:
        return sys.maxsize


class DocTagsRepetitionStopper(GenerationStopper):
    """
    Detects repetitive <tag>...<loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
    but only when repeats are **consecutive** and both tag & inner text are identical.

    Performance:
    - Heavy check runs every N calls (default 32).
    - Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
    """

    def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
        self.N = max(1, int(N))
        self._lookback_tokens = max(1, int(lookback_tokens))
        self._call_count = 0

        # <tag> ... <loc_x><loc_y><loc_w><loc_h> text ... </tag>
        self._PATTERN = re.compile(
            r"""
            <(?P<tag>[a-zA-Z0-9_]+)>\s*
            (?P<prefix>.*?)?
            <loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
            (?P<text>.*?)
            </(?P=tag)>
            """,
            re.DOTALL | re.VERBOSE,
        )

    # --- small helper ---
    def _regular(self, vals: List[int]) -> bool:
        """3+ strictly increasing values with ~regular spacing (±20%)."""
        if len(vals) < 3:
            return False
        diffs = [b - a for a, b in zip(vals, vals[1:])]
        if any(d <= 0 for d in diffs):
            return False
        mean = sum(diffs) / len(diffs)
        tol = 0.2 * mean
        return all(abs(d - mean) <= tol for d in diffs)

    def should_stop(self, s: str) -> bool:
        """
        Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
        with the same <tag> and identical inner text, where within that run we see:
          - any exact duplicate (x,y,w,h), or
          - stable X/W with regular Y progression, or
          - stable Y/H with regular X progression.
        """
        # Stream matches and evaluate runs on-the-fly to stay compact and fast.
        prev_tag = prev_text = None
        run = []  # list of (x,y,w,h)

        def run_repetitive(boxes: List[tuple]) -> bool:
            if len(boxes) < 3:
                return False
            # duplicates?
            if len(set(boxes)) < len(boxes):
                return True
            xs, ys, ws, hs = zip(*boxes)
            x_stable = all(x == xs[0] for x in xs)
            y_stable = all(y == ys[0] for y in ys)
            w_stable = all(w == ws[0] for w in ws)
            h_stable = all(h == hs[0] for h in hs)
            # horizontal (down the page): X/W stable, Y regular
            if (x_stable or w_stable) and self._regular(list(ys)):
                return True
            # vertical (across): Y/H stable, X regular
            if (y_stable or h_stable) and self._regular(list(xs)):
                return True
            return False

        for m in self._PATTERN.finditer(s):
            tag, text = m.group("tag"), m.group("text")
            box = (
                int(m.group("x")),
                int(m.group("y")),
                int(m.group("w")),
                int(m.group("h")),
            )

            if prev_tag == tag and prev_text == text:
                run.append(box)  # consecutive same-tag+text
            else:
                # evaluate previous run before starting a new one
                if run_repetitive(run):
                    return True
                prev_tag, prev_text = tag, text
                run = [box]

        # check the last run
        return run_repetitive(run)


class HFStoppingCriteriaWrapper(StoppingCriteria):
    """
    Adapts any GenerationStopper to HuggingFace Transformers.
    Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
    """

    def __init__(
        self,
        tokenizer,
        stopper: GenerationStopper,
        *,
        skip_special_tokens: bool = False,
    ):
        self.tokenizer = tokenizer
        self.stopper = stopper
        self.skip_special_tokens = skip_special_tokens

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        lb = max(1, int(self.stopper.lookback_tokens()))
        for seq in input_ids:  # (batch, seq_len)
            window = seq[-lb:]  # slicing handles lb > len(seq)
            try:
                text = self.tokenizer.decode(
                    window, skip_special_tokens=self.skip_special_tokens
                )
            except Exception as e:
                _log.info(f"Decoding failed for stopping check: {e}")
                continue

            try:
                if self.stopper.should_stop(text):
                    _log.info(
                        "HF wrapper: stopping due to TextStopper.should_stop==True"
                    )
                    return True
            except Exception as e:
                _log.info(f"Error in TextStopper.should_stop: {e}")
                continue
        return False
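To illustrate the detection logic, a small self-check on synthetic DocTags strings (illustrative data, not from the test suite): three consecutive blocks with identical tag, text, and coordinates trip the duplicate rule, while blocks with distinct inner text never form a run.

from docling.models.utils.generation_utils import DocTagsRepetitionStopper

stopper = DocTagsRepetitionStopper()

# Three identical consecutive blocks: same tag, same text, duplicate boxes -> stop.
looping = "<text><loc_10><loc_20><loc_100><loc_12>Lorem ipsum</text>" * 3
assert stopper.should_stop(looping)

# Distinct inner text per block: no consecutive run forms -> keep generating.
healthy = "".join(
    f"<text><loc_10><loc_{20 + 14 * i}><loc_100><loc_12>line {i}</text>"
    for i in range(3)
)
assert not stopper.should_stop(healthy)
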
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -253,15 +257,48 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         inputs = {k: v.to(self.device) for k, v in inputs.items()}

         # -- Optional stopping criteria
-        stopping_criteria = None
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-            stopping_criteria = StoppingCriteriaList(
-                [
+            stopping_criteria_list.append(
                 StopStringCriteria(
                     stop_strings=self.vlm_options.stop_strings,
                     tokenizer=self.processor.tokenizer,
                 )
-                ]
             )
+
+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class, instantiate with tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )

         # -- Filter out decoder-specific keys from extra_generation_config

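As a standalone usage sketch of the wrapper outside the pipeline (the repo id below is a placeholder, and the generate() call is only indicated), it plugs into transformers like any other stopping criterion:

from transformers import AutoProcessor, StoppingCriteriaList

from docling.models.utils.generation_utils import (
    DocTagsRepetitionStopper,
    HFStoppingCriteriaWrapper,
)

processor = AutoProcessor.from_pretrained("your-org/your-vlm")  # placeholder repo id
criteria = StoppingCriteriaList(
    [HFStoppingCriteriaWrapper(processor.tokenizer, DocTagsRepetitionStopper())]
)

# output_ids = model.generate(**inputs, stopping_criteria=criteria, max_new_tokens=8192)
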
@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union

 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -69,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)

+            # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    if isinstance(criteria, StoppingCriteria):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                            f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                        )
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, StoppingCriteria
+                    ):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                            f"Found {criteria.__name__}. Use GenerationStopper instead."
+                        )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -193,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 self.processor, self.config, user_prompt, num_images=1
             )

-            # Stream generate with stop strings support
+            # Stream generate with stop strings and custom stopping criteria support
             start_time = time.time()
             _log.debug("start generating ...")

@@ -245,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     _log.debug("Stopping generation due to stop string match")
                     break

+                # Check for custom stopping criteria (GenerationStopper instances)
+                if self.vlm_options.custom_stopping_criteria:
+                    for criteria in self.vlm_options.custom_stopping_criteria:
+                        # Handle both instances and classes of GenerationStopper
+                        if isinstance(criteria, GenerationStopper):
+                            stopper = criteria
+                        elif isinstance(criteria, type) and issubclass(
+                            criteria, GenerationStopper
+                        ):
+                            stopper = criteria()
+
+                        # Determine the text window to check based on lookback_tokens
+                        lookback_tokens = stopper.lookback_tokens()
+                        # Check only the last N characters worth of text
+                        # This is a simplified approach - in practice, you might want to
+                        # decode the last N tokens from the token list for more accuracy
+                        text_to_check = (
+                            output[-lookback_tokens:]
+                            if len(output) > lookback_tokens
+                            else output
+                        )
+
+                        try:
+                            if stopper.should_stop(text_to_check):
+                                _log.info(
+                                    f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                                )
+                                break
+                        except Exception as e:
+                            _log.warning(
+                                f"Error in GenerationStopper.should_stop: {e}"
+                            )
+                            continue
+                    else:  # note: for-else idiom
+                        continue  # Only executed if the inner loop didn't break
+                    break  # Break the outer loop if any stopper triggered
+
             generation_time = time.time() - start_time

             _log.debug(

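The for-else in the hunk above is easy to misread, so here is the bare control-flow pattern it relies on: the else clause runs only when the inner loop finishes without a break, and its continue keeps the outer loop alive.

for step in range(5):  # stand-in for the token stream
    for fired in (False, False, True):  # stand-in for the stoppers
        if fired:
            break  # a stopper triggered: skip the else below
    else:
        continue  # no stopper triggered: keep generating
    print(f"stopped at step {step}")
    break  # propagate the inner break to the outer loop
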
@@ -1,13 +1,15 @@
 import base64
 import json
 import logging
 from io import BytesIO
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import requests
 from PIL import Image
 from pydantic import AnyUrl

 from docling.datamodel.base_models import OpenAiApiResponse
+from docling.models.utils.generation_utils import GenerationStopper

 _log = logging.getLogger(__name__)

@@ -59,3 +61,107 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     return generated_text
+
+
+def api_image_request_streaming(
+    image: Image.Image,
+    prompt: str,
+    url: AnyUrl,
+    *,
+    timeout: float = 20,
+    headers: Optional[Dict[str, str]] = None,
+    generation_stoppers: List[GenerationStopper] = [],
+    **params,
+) -> str:
+    """
+    Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
+    Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
+    Accumulates text and calls stopper.should_stop(window) as chunks arrive.
+    If a stopper triggers, the HTTP connection is closed to abort server-side generation.
+    """
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_b64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
+                },
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        "stream": True,  # <-- critical for SSE streaming
+        **params,
+    }
+
+    # Debug: Log the payload to verify temperature is included
+    _log.debug(f"API streaming request payload: {json.dumps(payload, indent=2)}")
+
+    # Some servers require Accept: text/event-stream for SSE.
+    # It's safe to set it; OpenAI-compatible servers tolerate it.
+    hdrs = {"Accept": "text/event-stream", **(headers or {})}
+
+    # Try to force temperature via header if server ignores payload parameter
+    if "temperature" in params:
+        hdrs["X-Temperature"] = str(params["temperature"])
+
+    # Stream the HTTP response
+    with requests.post(
+        str(url), headers=hdrs, json=payload, timeout=timeout, stream=True
+    ) as r:
+        if not r.ok:
+            _log.error(
+                f"Error calling the API {url} in streaming mode. Response was {r.text}"
+            )
+        r.raise_for_status()
+
+        full_text = []
+        for raw_line in r.iter_lines(decode_unicode=True):
+            if not raw_line:  # keep-alives / blank lines
+                continue
+            if not raw_line.startswith("data:"):
+                # Some proxies inject comments; ignore anything not starting with 'data:'
+                continue
+
+            data = raw_line[len("data:") :].strip()
+            if data == "[DONE]":
+                break
+
+            try:
+                obj = json.loads(data)
+            except json.JSONDecodeError:
+                _log.debug("Skipping non-JSON SSE chunk: %r", data[:200])
+                continue
+
+            # OpenAI-compatible delta format
+            # obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls)
+            try:
+                delta = obj["choices"][0].get("delta") or {}
+                piece = delta.get("content") or ""
+            except (KeyError, IndexError) as e:
+                _log.debug("Unexpected SSE chunk shape: %s", e)
+                piece = ""
+
+            if piece:
+                full_text.append(piece)
+                for stopper in generation_stoppers:
+                    # Respect stopper's lookback window. We use a simple string window which
+                    # works with the GenerationStopper interface.
+                    lookback = max(1, stopper.lookback_tokens())
+                    window = "".join(full_text)[-lookback:]
+                    if stopper.should_stop(window):
+                        # Break out of the loop cleanly. The context manager will handle
+                        # closing the connection when we exit the 'with' block.
+                        # vLLM/OpenAI-compatible servers will detect the client disconnect
+                        # and abort the request server-side.
+                        return "".join(full_text)
+
+    return "".join(full_text)

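For reference, a typical SSE line that this loop consumes looks like the following (illustrative payload, not captured from a real server); the delta extraction keeps only the content piece:

import json

raw_line = 'data: {"choices": [{"delta": {"content": "<text><loc_10>"}}]}'

data = raw_line[len("data:") :].strip()
obj = json.loads(data)
delta = obj["choices"][0].get("delta") or {}
piece = delta.get("content") or ""
assert piece == "<text><loc_10>"
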
docs/examples/granitedocling_repetition_stopping.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# %% [markdown]
# Experimental VLM pipeline with custom repetition stopping criteria.
#
# This script demonstrates the use of custom stopping criteria that detect
# repetitive location coordinate patterns in generated text and stop generation
# when such patterns are found.
#
# What this example does
# - Uses the GraniteDocling model with custom repetition stopping criteria injected
# - Processes a PDF document or image and monitors for repetitive coordinate patterns
# - Stops generation early when repetitive patterns are detected


# %%

import logging

from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.models.utils.generation_utils import (
    DocTagsRepetitionStopper,
)
from docling.pipeline.vlm_pipeline import VlmPipeline

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")


# Set up logging to see when repetition stopping is triggered
logging.basicConfig(level=logging.INFO)

# Replace with a local path if preferred.
# source = "https://ibm.biz/docling-page-with-table"  # Example that shows no repetitions.
source = "tests/data_scanned/old_newspaper.png"  # Example that creates repetitions.
print(f"Processing document: {source}")

###### USING GRANITEDOCLING WITH CUSTOM REPETITION STOPPING

## Using standard Huggingface Transformers (most portable, slowest)
custom_vlm_options = vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.model_copy()

# Uncomment this to use MLX-accelerated version on Apple Silicon
# custom_vlm_options = vlm_model_specs.GRANITEDOCLING_MLX.model_copy()  # use this for Apple Silicon


# Create custom VLM options with repetition stopping criteria
custom_vlm_options.custom_stopping_criteria = [
    DocTagsRepetitionStopper(N=32)
]  # check for repetitions for every 32 new tokens decoded.

pipeline_options = VlmPipelineOptions(
    vlm_options=custom_vlm_options,
)

converter = DocumentConverter(
    format_options={
        InputFormat.IMAGE: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        ),
    }
)

doc = converter.convert(source=source).document

print(doc.export_to_markdown())

## Using a remote VLM inference service (for example VLLM) - uncomment to use

# custom_vlm_options = ApiVlmOptions(
#     url="http://localhost:8000/v1/chat/completions",  # LM studio defaults to port 1234, VLLM to 8000
#     params=dict(
#         model=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.repo_id,
#         max_tokens=8192,
#         skip_special_tokens=True,  # needed for VLLM
#     ),
#     headers={
#         "Authorization": "Bearer YOUR_API_KEY",
#     },
#     prompt=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.prompt,
#     timeout=90,
#     scale=2.0,
#     temperature=0.0,
#     response_format=ResponseFormat.DOCTAGS,
#     custom_stopping_criteria=[
#         DocTagsRepetitionStopper(N=1)
#     ],  # check for repetitions for every new chunk of the response stream
# )


# pipeline_options = VlmPipelineOptions(
#     vlm_options=custom_vlm_options,
#     enable_remote_services=True,  # required when using a remote inference service.
# )

# converter = DocumentConverter(
#     format_options={
#         InputFormat.IMAGE: PdfFormatOption(
#             pipeline_cls=VlmPipeline,
#             pipeline_options=pipeline_options,
#         ),
#     }
# )

# doc = converter.convert(source=source).document

# print(doc.export_to_markdown())
@@ -217,7 +217,7 @@ classmethod-decorators = [
 "tests/*.py" = ["ASYNC"]  # Disable ASYNC check for tests

 [tool.ruff.lint.mccabe]
-max-complexity = 20
+max-complexity = 30

 # [tool.ruff.lint.isort.sections]
 # "docling" = ["docling_core", "docling_ibm_models", "docling_parse"]

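Raising the mccabe threshold from 20 to 30 is what allows the per-function # noqa: C901 suppressions to be dropped in the backend and model files above.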
tests/data_scanned/old_newspaper.png (new binary file, 4.0 MiB, not shown)