diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index ddc06f28..078f68ba 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -60,7 +60,7 @@ jobs:
         run: |
           for file in docs/examples/*.py; do
             # Skip batch_convert.py
-            if [[ "$(basename "$file")" =~ ^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
+            if [[ "$(basename "$file")" =~ ^(batch_convert|granitedocling_repetition_stopping|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
               echo "Skipping $file"
               continue
             fi
diff --git a/docling/backend/asciidoc_backend.py b/docling/backend/asciidoc_backend.py
index 859646ee..0bebb1c5 100644
--- a/docling/backend/asciidoc_backend.py
+++ b/docling/backend/asciidoc_backend.py
@@ -78,7 +78,7 @@ class AsciiDocBackend(DeclarativeDocumentBackend):

         return doc

-    def _parse(self, doc: DoclingDocument):  # noqa: C901
+    def _parse(self, doc: DoclingDocument):
         """
         Main function that orchestrates the parsing by yielding components:
         title, section headers, text, lists, and tables.
diff --git a/docling/backend/msword_backend.py b/docling/backend/msword_backend.py
index 53653f5e..6bc9b906 100644
--- a/docling/backend/msword_backend.py
+++ b/docling/backend/msword_backend.py
@@ -812,7 +812,7 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
             else prev_parent
         )

-    def _handle_text_elements(  # noqa: C901
+    def _handle_text_elements(
         self,
         element: BaseOxmlElement,
         docx_obj: DocxDocument,
diff --git a/docling/backend/xml/jats_backend.py b/docling/backend/xml/jats_backend.py
index 24a41b56..f14e2d03 100755
--- a/docling/backend/xml/jats_backend.py
+++ b/docling/backend/xml/jats_backend.py
@@ -352,7 +352,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):

         return

-    def _parse_element_citation(self, node: etree._Element) -> str:  # noqa: C901
+    def _parse_element_citation(self, node: etree._Element) -> str:
         citation: Citation = {
             "author_names": "",
             "title": "",
@@ -538,7 +538,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         return

     @staticmethod
-    def parse_table_data(element: Tag) -> Optional[TableData]:  # noqa: C901
+    def parse_table_data(element: Tag) -> Optional[TableData]:
         # TODO, see how to implement proper support for rich tables from HTML backend
         nested_tables = element.find("table")
         if nested_tables is not None:
@@ -713,7 +713,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
         )
         return

-    def _walk_linear(  # noqa: C901
+    def _walk_linear(
         self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
     ) -> str:
         skip_tags = ["term"]
diff --git a/docling/backend/xml/uspto_backend.py b/docling/backend/xml/uspto_backend.py
index 268b80ad..0569a5fb 100644
--- a/docling/backend/xml/uspto_backend.py
+++ b/docling/backend/xml/uspto_backend.py
@@ -1523,7 +1523,7 @@ class XmlTable:

         return ncols_max

-    def _parse_table(self, table: Tag) -> TableData:  # noqa: C901
+    def _parse_table(self, table: Tag) -> TableData:
         """Parse the content of a table tag.

         Args:
diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index 8502822e..02014c34 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -1,11 +1,13 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Union

 from docling_core.types.doc.page import SegmentedPage
-from pydantic import AnyUrl, BaseModel
+from pydantic import AnyUrl, BaseModel, ConfigDict
+from transformers import StoppingCriteria
 from typing_extensions import deprecated

 from docling.datamodel.accelerator_options import AcceleratorDevice
+from docling.models.utils.generation_utils import GenerationStopper


 class BaseVlmOptions(BaseModel):
@@ -50,6 +52,8 @@ class TransformersPromptStyle(str, Enum):


 class InlineVlmOptions(BaseVlmOptions):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     kind: Literal["inline_model_options"] = "inline_model_options"

     repo_id: str
@@ -72,6 +76,7 @@ class InlineVlmOptions(BaseVlmOptions):
     ]

     stop_strings: List[str] = []
+    custom_stopping_criteria: List[Union[StoppingCriteria, GenerationStopper]] = []
     extra_generation_config: Dict[str, Any] = {}
     extra_processor_kwargs: Dict[str, Any] = {}

@@ -89,6 +94,8 @@ class HuggingFaceVlmOptions(InlineVlmOptions):


 class ApiVlmOptions(BaseVlmOptions):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     kind: Literal["api_model_options"] = "api_model_options"

     url: AnyUrl = AnyUrl(
@@ -99,3 +106,6 @@ class ApiVlmOptions(BaseVlmOptions):
     timeout: float = 60
     concurrency: int = 1
     response_format: ResponseFormat
+
+    stop_strings: List[str] = []
+    custom_stopping_criteria: List[GenerationStopper] = []
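The two new fields are easiest to see in isolation. A minimal sketch of how a caller attaches them to a copied model spec, mirroring the example script later in this patch (the stop string value is illustrative, not part of this change):

```python
from docling.datamodel import vlm_model_specs
from docling.models.utils.generation_utils import DocTagsRepetitionStopper

# Copy a bundled spec and attach the new stopping options.
# arbitrary_types_allowed is what lets the list hold plain Python objects.
opts = vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.model_copy()
opts.stop_strings = ["</doctag>"]  # illustrative stop string
opts.custom_stopping_criteria = [DocTagsRepetitionStopper(N=32)]
```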
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index c48aa0bc..bcdb97f5 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -1,12 +1,18 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor

+from transformers import StoppingCriteria
+
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
 from docling.models.base_model import BasePageModel
-from docling.utils.api_image_request import api_image_request
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.utils.api_image_request import (
+    api_image_request,
+    api_image_request_streaming,
+)
 from docling.utils.profiling import TimeRecorder

@@ -41,19 +47,43 @@
             assert page._backend is not None
             if not page._backend.is_valid():
                 return page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-                    hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+            with TimeRecorder(conv_res, "vlm"):
+                assert page.size is not None
+
+                hi_res_image = page.get_image(
+                    scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                )
+                assert hi_res_image is not None
+                if hi_res_image and hi_res_image.mode != "RGB":
+                    hi_res_image = hi_res_image.convert("RGB")
+
+                prompt = self.vlm_options.build_prompt(page.parsed_page)
+
+                if self.vlm_options.custom_stopping_criteria:
+                    # Instantiate any GenerationStopper classes before passing to streaming
+                    instantiated_stoppers = []
+                    for criteria in self.vlm_options.custom_stopping_criteria:
+                        if isinstance(criteria, GenerationStopper):
+                            instantiated_stoppers.append(criteria)
+                        elif isinstance(criteria, type) and issubclass(
+                            criteria, GenerationStopper
+                        ):
+                            instantiated_stoppers.append(criteria())
+                        # Skip non-GenerationStopper criteria (should have been caught in validation)
+
+                    # Streaming path with early abort support
+                    page_tags = api_image_request_streaming(
+                        image=hi_res_image,
+                        prompt=prompt,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        generation_stoppers=instantiated_stoppers,
+                        **self.params,
                     )
-                    assert hi_res_image is not None
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-
-                    prompt = self.vlm_options.build_prompt(page.parsed_page)
+                else:
+                    # Non-streaming fallback (existing behavior)
                     page_tags = api_image_request(
                         image=hi_res_image,
                         prompt=prompt,
@@ -63,10 +93,9 @@
                         **self.params,
                     )

-                    page_tags = self.vlm_options.decode_response(page_tags)
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
-
-                return page
+                page_tags = self.vlm_options.decode_response(page_tags)
+                page.predictions.vlm_response = VlmPrediction(text=page_tags)
+                return page

         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_vlm_request, page_batch)
diff --git a/docling/models/readingorder_model.py b/docling/models/readingorder_model.py
index 375ad4e4..47d4bb31 100644
--- a/docling/models/readingorder_model.py
+++ b/docling/models/readingorder_model.py
@@ -103,7 +103,7 @@ class ReadingOrderModel:
         else:
             doc.add_text(parent=doc_item, label=c_label, text=c_text, prov=c_prov)

-    def _readingorder_elements_to_docling_doc(  # noqa: C901
+    def _readingorder_elements_to_docling_doc(
         self,
         conv_res: ConversionResult,
         ro_elements: List[ReadingOrderPageElement],
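The instantiation loop in api_vlm_model.py above normalizes `custom_stopping_criteria` before the streaming call: instances pass through, zero-argument classes are constructed, anything else is dropped. The same rule in isolation, as a sketch:

```python
from docling.models.utils.generation_utils import (
    DocTagsRepetitionStopper,
    GenerationStopper,
)


def normalize(criteria_list):
    """Reduce a mixed list of stopper classes/instances to instances only."""
    stoppers = []
    for criteria in criteria_list:
        if isinstance(criteria, GenerationStopper):
            stoppers.append(criteria)  # already usable as-is
        elif isinstance(criteria, type) and issubclass(criteria, GenerationStopper):
            stoppers.append(criteria())  # zero-argument construction
    return stoppers


# A class and an instance both normalize to usable stoppers.
print(normalize([DocTagsRepetitionStopper, DocTagsRepetitionStopper(N=8)]))
```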
diff --git a/docling/models/utils/generation_utils.py b/docling/models/utils/generation_utils.py
new file mode 100644
index 00000000..a6502911
--- /dev/null
+++ b/docling/models/utils/generation_utils.py
@@ -0,0 +1,165 @@
+import logging
+import re
+import sys
+from abc import ABC, abstractmethod
+from typing import List
+
+from transformers import StoppingCriteria
+
+_log = logging.getLogger(__name__)
+
+
+class GenerationStopper(ABC):
+    """
+    Base interface for stopping logic.
+    - should_stop(s): True to stop given the current decoded text window.
+    - lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
+    """
+
+    @abstractmethod
+    def should_stop(self, s: str) -> bool:
+        pass
+
+    def lookback_tokens(self) -> int:
+        return sys.maxsize
+
+
+class DocTagsRepetitionStopper(GenerationStopper):
+    """
+    Detects repetitive <tag><loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
+    but only when repeats are **consecutive** and both tag & inner text are identical.
+
+    Performance:
+    - Heavy check runs every N calls (default 32).
+    - Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
+    """
+
+    def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
+        self.N = max(1, int(N))
+        self._lookback_tokens = max(1, int(lookback_tokens))
+        self._call_count = 0
+
+        # <tag><loc_x><loc_y><loc_w><loc_h>text</tag>
+        self._PATTERN = re.compile(
+            r"""
+            <(?P<tag>[a-zA-Z0-9_]+)>\s*
+            (?P<pre>.*?)?
+            <loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
+            (?P<text>.*?)
+            </(?P=tag)>
+            """,
+            re.DOTALL | re.VERBOSE,
+        )
+
+    def lookback_tokens(self) -> int:
+        return self._lookback_tokens
+
+    # --- small helper ---
+    def _regular(self, vals: List[int]) -> bool:
+        """3+ strictly increasing values with ~regular spacing (±20%)."""
+        if len(vals) < 3:
+            return False
+        diffs = [b - a for a, b in zip(vals, vals[1:])]
+        if any(d <= 0 for d in diffs):
+            return False
+        mean = sum(diffs) / len(diffs)
+        tol = 0.2 * mean
+        return all(abs(d - mean) <= tol for d in diffs)
+
+    def should_stop(self, s: str) -> bool:
+        """
+        Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
+        with the same <tag> and identical inner text, where within that run we see:
+        - any exact duplicate (x,y,w,h), or
+        - stable X/W with regular Y progression, or
+        - stable Y/H with regular X progression.
+        """
+        # Run the heavy check only on every N-th call.
+        self._call_count += 1
+        if self._call_count % self.N != 0:
+            return False
+
+        # Stream matches and evaluate runs on-the-fly to stay compact and fast.
+        prev_tag = prev_text = None
+        run = []  # list of (x, y, w, h)
+
+        def run_repetitive(boxes: List[tuple]) -> bool:
+            if len(boxes) < 3:
+                return False
+            # duplicates?
+            if len(set(boxes)) < len(boxes):
+                return True
+            xs, ys, ws, hs = zip(*boxes)
+            x_stable = all(x == xs[0] for x in xs)
+            y_stable = all(y == ys[0] for y in ys)
+            w_stable = all(w == ws[0] for w in ws)
+            h_stable = all(h == hs[0] for h in hs)
+            # horizontal (down the page): X/W stable, Y regular
+            if (x_stable or w_stable) and self._regular(list(ys)):
+                return True
+            # vertical (across): Y/H stable, X regular
+            if (y_stable or h_stable) and self._regular(list(xs)):
+                return True
+            return False
+
+        for m in self._PATTERN.finditer(s):
+            tag, text = m.group("tag"), m.group("text")
+            box = (
+                int(m.group("x")),
+                int(m.group("y")),
+                int(m.group("w")),
+                int(m.group("h")),
+            )
+
+            if prev_tag == tag and prev_text == text:
+                run.append(box)  # consecutive same-tag+text
+            else:
+                # evaluate previous run before starting a new one
+                if run_repetitive(run):
+                    return True
+                prev_tag, prev_text = tag, text
+                run = [box]
+
+        # check the last run
+        return run_repetitive(run)
+
+
+class HFStoppingCriteriaWrapper(StoppingCriteria):
+    """
+    Adapts any GenerationStopper to HuggingFace Transformers.
+    Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        stopper: GenerationStopper,
+        *,
+        skip_special_tokens: bool = False,
+    ):
+        self.tokenizer = tokenizer
+        self.stopper = stopper
+        self.skip_special_tokens = skip_special_tokens
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        lb = max(1, int(self.stopper.lookback_tokens()))
+        for seq in input_ids:  # (batch, seq_len)
+            window = seq[-lb:]  # slicing handles lb > len(seq)
+            try:
+                text = self.tokenizer.decode(
+                    window, skip_special_tokens=self.skip_special_tokens
+                )
+            except Exception as e:
+                _log.info(f"Decoding failed for stopping check: {e}")
+                continue
+
+            try:
+                if self.stopper.should_stop(text):
+                    _log.info(
+                        "HF wrapper: stopping because GenerationStopper.should_stop returned True"
+                    )
+                    return True
+            except Exception as e:
+                _log.info(f"Error in GenerationStopper.should_stop: {e}")
+                continue
+        return False
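Implementing a stopper against this interface takes only a few lines. A sketch of a custom stopper with a deliberately simpler repetition rule than DocTagsRepetitionStopper (the threshold and lookback values are illustrative):

```python
from collections import Counter

from docling.models.utils.generation_utils import GenerationStopper


class LineRepetitionStopper(GenerationStopper):
    """Stop when any non-empty line repeats `threshold` times in the decoded window."""

    def __init__(self, threshold: int = 10):
        self.threshold = threshold

    def should_stop(self, s: str) -> bool:
        counts = Counter(line for line in s.splitlines() if line.strip())
        return bool(counts) and max(counts.values()) >= self.threshold

    def lookback_tokens(self) -> int:
        return 512  # only the tail of the sequence needs to be decoded
```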
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 25eb9b88..1f4d752c 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -253,17 +257,50 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             inputs = {k: v.to(self.device) for k, v in inputs.items()}

         # -- Optional stopping criteria
-        stopping_criteria = None
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-            stopping_criteria = StoppingCriteriaList(
-                [
-                    StopStringCriteria(
-                        stop_strings=self.vlm_options.stop_strings,
-                        tokenizer=self.processor.tokenizer,
-                    )
-                ]
+            stopping_criteria_list.append(
+                StopStringCriteria(
+                    stop_strings=self.vlm_options.stop_strings,
+                    tokenizer=self.processor.tokenizer,
+                )
             )

+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate the GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class; instantiate it with the tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )
+
         # -- Filter out decoder-specific keys from extra_generation_config
         decoder_keys = {
             "skip_special_tokens",
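Because HFStoppingCriteriaWrapper is a regular transformers StoppingCriteria, it also composes with plain `generate()` calls outside docling. A sketch with an illustrative small model (gpt2 will not emit DocTags; the point is only the wiring):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from docling.models.utils.generation_utils import (
    DocTagsRepetitionStopper,
    HFStoppingCriteriaWrapper,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

criteria = StoppingCriteriaList(
    [HFStoppingCriteriaWrapper(tokenizer, DocTagsRepetitionStopper(N=32))]
)
inputs = tokenizer("Hello", return_tensors="pt")
# The wrapper is consulted at every decode step via `stopping_criteria`.
output = model.generate(**inputs, max_new_tokens=32, stopping_criteria=criteria)
print(tokenizer.decode(output[0]))
```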
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 1ee588c7..ac4cf9c8 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union

 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -69,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 self.vlm_model, self.processor = load(artifacts_path)
                 self.config = load_config(artifacts_path)

+        # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                if isinstance(criteria, StoppingCriteria):
+                    raise ValueError(
+                        f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                        f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                    )
+                elif isinstance(criteria, type) and issubclass(
+                    criteria, StoppingCriteria
+                ):
+                    raise ValueError(
+                        f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                        f"Found {criteria.__name__}. Use GenerationStopper instead."
+                    )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -193,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     self.processor, self.config, user_prompt, num_images=1
                 )

-                # Stream generate with stop strings support
+                # Stream generate with stop strings and custom stopping criteria support
                 start_time = time.time()
                 _log.debug("start generating ...")

@@ -245,6 +264,47 @@
                             _log.debug("Stopping generation due to stop string match")
                             break

+                    # Check for custom stopping criteria (GenerationStopper instances)
+                    if self.vlm_options.custom_stopping_criteria:
+                        for criteria in self.vlm_options.custom_stopping_criteria:
+                            # Handle both instances and classes of GenerationStopper
+                            if isinstance(criteria, GenerationStopper):
+                                stopper = criteria
+                            elif isinstance(criteria, type) and issubclass(
+                                criteria, GenerationStopper
+                            ):
+                                stopper = criteria()
+                            else:
+                                # Unsupported type; __init__ validation already rejects
+                                # HF StoppingCriteria, so skip anything else.
+                                continue
+
+                            # Determine the text window to check based on lookback_tokens
+                            lookback_tokens = stopper.lookback_tokens()
+                            # Check only the last N characters' worth of text.
+                            # This is a simplified approach - in practice, you might want to
+                            # decode the last N tokens from the token list for more accuracy.
+                            text_to_check = (
+                                output[-lookback_tokens:]
+                                if len(output) > lookback_tokens
+                                else output
+                            )
+
+                            try:
+                                if stopper.should_stop(text_to_check):
+                                    _log.info(
+                                        f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                                    )
+                                    break
+                            except Exception as e:
+                                _log.warning(
+                                    f"Error in GenerationStopper.should_stop: {e}"
+                                )
+                                continue
+                        else:  # note: for-else idiom
+                            continue  # Only executed if the inner loop didn't break
+                        break  # Break the outer loop if any stopper triggered
+
                 generation_time = time.time() - start_time

                 _log.debug(
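For reference, the kind of object this validation rejects is any transformers StoppingCriteria, whether passed as a class or as an instance. A minimal sketch (the criteria class is illustrative):

```python
from transformers import StoppingCriteria

from docling.datamodel import vlm_model_specs


class NoOpCriteria(StoppingCriteria):
    """Minimal HF StoppingCriteria; accepted by the transformers path, rejected by MLX."""

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return False  # never requests a stop


opts = vlm_model_specs.GRANITEDOCLING_MLX.model_copy()
opts.custom_stopping_criteria = [NoOpCriteria()]
# Constructing HuggingFaceMlxModel with these options raises the ValueError above.
```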
diff --git a/docling/utils/api_image_request.py b/docling/utils/api_image_request.py
index 9227389c..e85c6cad 100644
--- a/docling/utils/api_image_request.py
+++ b/docling/utils/api_image_request.py
@@ -1,13 +1,15 @@
 import base64
+import json
 import logging
 from io import BytesIO
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import requests
 from PIL import Image
 from pydantic import AnyUrl

 from docling.datamodel.base_models import OpenAiApiResponse
+from docling.models.utils.generation_utils import GenerationStopper

 _log = logging.getLogger(__name__)

@@ -59,3 +61,107 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     return generated_text
+
+
+def api_image_request_streaming(
+    image: Image.Image,
+    prompt: str,
+    url: AnyUrl,
+    *,
+    timeout: float = 20,
+    headers: Optional[Dict[str, str]] = None,
+    generation_stoppers: List[GenerationStopper] = [],
+    **params,
+) -> str:
+    """
+    Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
+    Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
+    Accumulates text and calls stopper.should_stop(window) as chunks arrive.
+    If a stopper triggers, the HTTP connection is closed to abort server-side generation.
+    """
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_b64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
+                },
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        "stream": True,  # <-- critical for SSE streaming
+        **params,
+    }
+
+    # Debug: log the payload to verify parameters such as temperature are included
+    _log.debug(f"API streaming request payload: {json.dumps(payload, indent=2)}")
+
+    # Some servers require Accept: text/event-stream for SSE.
+    # It's safe to set it; OpenAI-compatible servers tolerate it.
+    hdrs = {"Accept": "text/event-stream", **(headers or {})}
+
+    # Try to force temperature via header if the server ignores the payload parameter
+    if "temperature" in params:
+        hdrs["X-Temperature"] = str(params["temperature"])
+
+    # Stream the HTTP response
+    with requests.post(
+        str(url), headers=hdrs, json=payload, timeout=timeout, stream=True
+    ) as r:
+        if not r.ok:
+            _log.error(
+                f"Error calling the API {url} in streaming mode. Response was {r.text}"
+            )
+        r.raise_for_status()
+
+        full_text = []
+        for raw_line in r.iter_lines(decode_unicode=True):
+            if not raw_line:  # keep-alives / blank lines
+                continue
+            if not raw_line.startswith("data:"):
+                # Some proxies inject comments; ignore anything not starting with 'data:'
+                continue

+            data = raw_line[len("data:") :].strip()
+            if data == "[DONE]":
+                break
+
+            try:
+                obj = json.loads(data)
+            except json.JSONDecodeError:
+                _log.debug("Skipping non-JSON SSE chunk: %r", data[:200])
+                continue
+
+            # OpenAI-compatible delta format:
+            # obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls)
+            try:
+                delta = obj["choices"][0].get("delta") or {}
+                piece = delta.get("content") or ""
+            except (KeyError, IndexError) as e:
+                _log.debug("Unexpected SSE chunk shape: %s", e)
+                piece = ""
+
+            if piece:
+                full_text.append(piece)
+                for stopper in generation_stoppers:
+                    # Respect the stopper's lookback window. We use a simple string window,
+                    # which works with the GenerationStopper interface.
+                    lookback = max(1, stopper.lookback_tokens())
+                    window = "".join(full_text)[-lookback:]
+                    if stopper.should_stop(window):
+                        # Break out of the loop cleanly. The context manager will handle
+                        # closing the connection when we exit the 'with' block.
+                        # vLLM/OpenAI-compatible servers will detect the client disconnect
+                        # and abort the request server-side.
+                        return "".join(full_text)
+
+    return "".join(full_text)
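A direct call to the new helper against a local OpenAI-compatible endpoint might look as follows. This is a sketch: the URL, image path, and timeout are assumptions, and `model`/`max_tokens` are simply forwarded into the payload via `**params`:

```python
from PIL import Image
from pydantic import AnyUrl

from docling.datamodel import vlm_model_specs
from docling.models.utils.generation_utils import DocTagsRepetitionStopper
from docling.utils.api_image_request import api_image_request_streaming

text = api_image_request_streaming(
    image=Image.open("page.png"),  # assumption: a rasterized page image
    prompt=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.prompt,
    url=AnyUrl("http://localhost:8000/v1/chat/completions"),  # assumption: vLLM default
    timeout=90,
    generation_stoppers=[DocTagsRepetitionStopper(N=1)],
    model=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.repo_id,
    max_tokens=8192,
)
print(text)
```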
+ """ + img_io = BytesIO() + image.save(img_io, "PNG") + image_b64 = base64.b64encode(img_io.getvalue()).decode("utf-8") + + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_b64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + payload = { + "messages": messages, + "stream": True, # <-- critical for SSE streaming + **params, + } + + # Debug: Log the payload to verify temperature is included + _log.debug(f"API streaming request payload: {json.dumps(payload, indent=2)}") + + # Some servers require Accept: text/event-stream for SSE. + # It's safe to set it; OpenAI-compatible servers tolerate it. + hdrs = {"Accept": "text/event-stream", **(headers or {})} + + # Try to force temperature via header if server ignores payload parameter + if "temperature" in params: + hdrs["X-Temperature"] = str(params["temperature"]) + + # Stream the HTTP response + with requests.post( + str(url), headers=hdrs, json=payload, timeout=timeout, stream=True + ) as r: + if not r.ok: + _log.error( + f"Error calling the API {url} in streaming mode. Response was {r.text}" + ) + r.raise_for_status() + + full_text = [] + for raw_line in r.iter_lines(decode_unicode=True): + if not raw_line: # keep-alives / blank lines + continue + if not raw_line.startswith("data:"): + # Some proxies inject comments; ignore anything not starting with 'data:' + continue + + data = raw_line[len("data:") :].strip() + if data == "[DONE]": + break + + try: + obj = json.loads(data) + except json.JSONDecodeError: + _log.debug("Skipping non-JSON SSE chunk: %r", data[:200]) + continue + + # OpenAI-compatible delta format + # obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls) + try: + delta = obj["choices"][0].get("delta") or {} + piece = delta.get("content") or "" + except (KeyError, IndexError) as e: + _log.debug("Unexpected SSE chunk shape: %s", e) + piece = "" + + if piece: + full_text.append(piece) + for stopper in generation_stoppers: + # Respect stopper's lookback window. We use a simple string window which + # works with the GenerationStopper interface. + lookback = max(1, stopper.lookback_tokens()) + window = "".join(full_text)[-lookback:] + if stopper.should_stop(window): + # Break out of the loop cleanly. The context manager will handle + # closing the connection when we exit the 'with' block. + # vLLM/OpenAI-compatible servers will detect the client disconnect + # and abort the request server-side. + return "".join(full_text) + + return "".join(full_text) diff --git a/docs/examples/granitedocling_repetition_stopping.py b/docs/examples/granitedocling_repetition_stopping.py new file mode 100644 index 00000000..673cb488 --- /dev/null +++ b/docs/examples/granitedocling_repetition_stopping.py @@ -0,0 +1,108 @@ +# %% [markdown] +# Experimental VLM pipeline with custom repetition stopping criteria. +# +# This script demonstrates the use of custom stopping criteria that detect +# repetitive location coordinate patterns in generated text and stop generation +# when such patterns are found. 
+
+## Using a remote VLM inference service (for example VLLM) - uncomment to use
+
+# custom_vlm_options = ApiVlmOptions(
+#     url="http://localhost:8000/v1/chat/completions",  # LM Studio defaults to port 1234, VLLM to 8000
+#     params=dict(
+#         model=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.repo_id,
+#         max_tokens=8192,
+#         skip_special_tokens=True,  # needed for VLLM
+#     ),
+#     headers={
+#         "Authorization": "Bearer YOUR_API_KEY",
+#     },
+#     prompt=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS.prompt,
+#     timeout=90,
+#     scale=2.0,
+#     temperature=0.0,
+#     response_format=ResponseFormat.DOCTAGS,
+#     custom_stopping_criteria=[
+#         DocTagsRepetitionStopper(N=1)
+#     ],  # check for repetitions on every new chunk of the response stream
+# )
+
+
+# pipeline_options = VlmPipelineOptions(
+#     vlm_options=custom_vlm_options,
+#     enable_remote_services=True,  # required when using a remote inference service.
+# ) + +# converter = DocumentConverter( +# format_options={ +# InputFormat.IMAGE: PdfFormatOption( +# pipeline_cls=VlmPipeline, +# pipeline_options=pipeline_options, +# ), +# } +# ) + +# doc = converter.convert(source=source).document + +# print(doc.export_to_markdown()) diff --git a/pyproject.toml b/pyproject.toml index 56342308..782d3302 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,7 +217,7 @@ classmethod-decorators = [ "tests/*.py" = ["ASYNC"] # Disable ASYNC check for tests [tool.ruff.lint.mccabe] -max-complexity = 20 +max-complexity = 30 # [tool.ruff.lint.isort.sections] # "docling" = ["docling_core", "docling_ibm_models", "docling_parse"] diff --git a/tests/data_scanned/old_newspaper.png b/tests/data_scanned/old_newspaper.png new file mode 100644 index 00000000..14d14b2e Binary files /dev/null and b/tests/data_scanned/old_newspaper.png differ