diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index a24df89d..3d31d6a4 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -212,6 +212,12 @@ class OcrMacOptions(OcrOptions):
 class PictureDescriptionBaseOptions(BaseOptions):
     batch_size: int = 8
     scale: float = 2
+    text_context_window_size_before_picture: int = (
+        1  # Number of text items to consider before the image
+    )
+    text_context_window_size_after_picture: int = (
+        0  # Number of text items to consider after the image
+    )
 
     picture_area_threshold: float = (
         0.05  # percentage of the area for a picture to processed with the models
diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py
index 44bb5e21..d56bbba5 100644
--- a/docling/models/picture_description_api_model.py
+++ b/docling/models/picture_description_api_model.py
@@ -1,6 +1,7 @@
+import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
 
 from PIL import Image
 
@@ -13,6 +14,8 @@ from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.api_image_request import api_image_request
 
+_log = logging.getLogger(__name__)
+
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
     # elements_batch_size = 4
@@ -57,3 +60,22 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 headers=self.options.headers,
                 **self.options.params,
             )
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        # Note: technically we could make a batch request here,
+        # but not all APIs will allow for it. For example, vllm won't allow more than 1.
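+        # So we issue one request per picture, prepending its text context to the prompt.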
+        for image, context in image_context_map:
+            # Create context-aware prompt
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+
+            yield api_image_request(
+                image=image,
+                prompt=context_prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
+                headers=self.options.headers,
+                **self.options.params,
+            )
diff --git a/docling/models/picture_description_base_model.py b/docling/models/picture_description_base_model.py
index 2f6e6479..ab0d24d4 100644
--- a/docling/models/picture_description_base_model.py
+++ b/docling/models/picture_description_base_model.py
@@ -1,7 +1,8 @@
+import logging
 from abc import abstractmethod
 from collections.abc import Iterable
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union
 
 from docling_core.types.doc import (
     DoclingDocument,
@@ -23,6 +24,8 @@ from docling.models.base_model import (
     ItemAndImageEnrichmentElement,
 )
 
+_log = logging.getLogger(__name__)
+
 
 class PictureDescriptionBaseModel(
     BaseItemAndImageEnrichmentModel, BaseModelWithOptions
@@ -48,6 +51,70 @@ class PictureDescriptionBaseModel(
     def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         raise NotImplementedError
 
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        """Override this method to support context in concrete implementations."""
+        # Extract only the images from the (image, context) pairs
+        images = [image_context_pair[0] for image_context_pair in image_context_map]
+        # Default implementation ignores context
+        yield from self._annotate_images(images)
+
+    def _get_surrounding_text(
+        self, doc: DoclingDocument, picture_item: PictureItem
+    ) -> str:
+        """Get text context from items before and after the picture."""
+        context = []
+        text_items_before_picture, text_items_after_picture = [], []
+        found_picture = False
+        after_count = 0
+
+        _log.debug(
+            "Getting surrounding text for picture ref: %s", picture_item.self_ref
+        )
+        for item, _ in doc.iterate_items():
+            if item == picture_item:
+                found_picture = True
+                continue
+
+            if not found_picture:  # before picture
+                if isinstance(item, (str, NodeItem)):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        # hold all text items before the picture
+                        text_items_before_picture.append(text)
+            else:  # after picture
+                if (
+                    isinstance(item, (str, NodeItem))
+                    and after_count
+                    < self.options.text_context_window_size_after_picture
+                ):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        text_items_after_picture.append(text)
+                        after_count += 1
+
+                if after_count >= self.options.text_context_window_size_after_picture:
+                    # Stop if we have reached the limit of text items after the picture
+                    break
+
+        # Combine text items before and after the picture
+        if self.options.text_context_window_size_before_picture > 0:
+            # get only the last N text items before the picture
+            context.extend(
+                text_items_before_picture[
+                    -self.options.text_context_window_size_before_picture :
+                ]
+            )
+
+        if self.options.text_context_window_size_after_picture > 0:
+            context.extend(text_items_after_picture)
+
+        _log.debug("Context before picture: %s", text_items_before_picture)
+        _log.debug("Context after picture: %s", text_items_after_picture)
+        # Join the context with newlines
+        return "\n".join(context)
+
     def __call__(
         self,
         doc: DoclingDocument,
@@ -58,8 +125,9 @@ class PictureDescriptionBaseModel(
                 yield element.item
             return
 
-        images: List[Image.Image] = []
-        elements: List[PictureItem] = []
+        image_context_map: List[Tuple[Image.Image, str]] = []
+        pictures: List[PictureItem] = []
+
         for el in element_batch:
             assert isinstance(el.item, PictureItem)
             describe_image = True
@@ -74,16 +142,31 @@ class PictureDescriptionBaseModel(
             if area_fraction < self.options.picture_area_threshold:
                 describe_image = False
             if describe_image:
-                elements.append(el.item)
-                images.append(el.image)
+                pictures.append(el.item)
+                context = ""
+                if (
+                    self.options.text_context_window_size_before_picture > 0
+                    or self.options.text_context_window_size_after_picture > 0
+                ):
+                    # Get the surrounding text context
+                    context = self._get_surrounding_text(doc, el.item)
+                image_context_map.append((el.image, context))
 
-        outputs = self._annotate_images(images)
-
-        for item, output in zip(elements, outputs):
-            item.annotations.append(
-                PictureDescriptionData(text=output, provenance=self.provenance)
+        if (
+            self.options.text_context_window_size_before_picture > 0
+            or self.options.text_context_window_size_after_picture > 0
+        ):
+            picture_descriptions = self._annotate_with_context(image_context_map)
+        else:
+            picture_descriptions = self._annotate_images(
+                image for image, _ in image_context_map
             )
-            yield item
+
+        for picture, description in zip(pictures, picture_descriptions):
+            picture.annotations.append(
+                PictureDescriptionData(text=description, provenance=self.provenance)
+            )
+            yield picture
 
     @classmethod
     @abstractmethod
diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py
index 679e80c2..bcded748 100644
--- a/docling/models/picture_description_vlm_model.py
+++ b/docling/models/picture_description_vlm_model.py
@@ -1,6 +1,7 @@
+import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
 
 from PIL import Image
 
@@ -12,6 +13,8 @@ from docling.datamodel.pipeline_options import (
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.accelerator_utils import decide_device
 
+_log = logging.getLogger(__name__)
+
 
 class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
     @classmethod
@@ -121,3 +124,42 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             )
 
             yield generated_texts[0].strip()
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        from transformers import GenerationConfig
+
+        for image, context in image_context_map:
+            # Create input messages with context
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": context_prompt},
+                    ],
+                },
+            ]
+
+            # Prepare inputs
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            # Generate outputs
+            generated_ids = self.model.generate(
+                **inputs,
+                generation_config=GenerationConfig(**self.options.generation_config),
+            )
+            generated_texts = self.processor.batch_decode(
+                generated_ids[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )
+
+            yield generated_texts[0].strip()
diff --git a/pyproject.toml b/pyproject.toml
index abbe4165..6cea15ba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -283,3 +283,9 @@ branch = "main"
 parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
 parser_angular_minor_types = "feat"
 parser_angular_patch_types = "fix,perf"
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
+log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
\ No newline at end of file
diff --git a/tests/test_picture_description.py b/tests/test_picture_description.py
new file mode 100644
index 00000000..42551aba
--- /dev/null
+++ b/tests/test_picture_description.py
@@ -0,0 +1,150 @@
+import logging
+from pathlib import Path
+
+import pytest
+import requests
+from docling_core.types.doc.document import PictureDescriptionData
+
+from docling.datamodel.base_models import ConversionStatus, InputFormat
+from docling.datamodel.pipeline_options import (
+    PdfPipelineOptions,
+    PictureDescriptionApiOptions,
+    PictureDescriptionVlmOptions,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+
+# Configure logging at the top of the file
+logging.basicConfig(level=logging.DEBUG)
+_log = logging.getLogger(__name__)
+
+IMAGE_RESOLUTION_SCALE = 2.0
+LOCAL_VISION_MODEL = "ibm-granite/granite-vision-3.2-2b"
+# LOCAL_VISION_MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"
+API_VISION_MODEL = "granite3.2-vision:2b"
+REMOTE_CHAT_API_URL = "http://localhost:8321/v1/openai/v1/chat/completions"  # for llama-stack OpenAI API interface
+DOC_SOURCE = "https://www.allspringglobal.com/globalassets/assets/regulatory/summary-prospectus/emerging-markets-equity-summ.pdf"
+PROMPT = (
+    "Please describe the image using the text above as additional context. "
+    "Additionally, if the image contains a chart (like a bar chart, pie chart, line chart, etc.), "
+    "please try to extract a list of data points (percentages, numbers, etc.) that are depicted in the chart. "
+    "Also, based on the type of information extracted, "
+    "when applicable try to summarize it using bullet points or even a tabular representation using markdown if possible."
+)
+
+
+def is_api_available(url: str, timeout: int = 3) -> bool:
+    try:
+        requests.get(url, timeout=timeout)
+        return True
+    except (requests.ConnectionError, requests.Timeout) as e:
+        _log.debug(f"API endpoint {url} is not reachable: {e!s}")
+        return False
+
+
+def process_document(pipeline_options: PdfPipelineOptions):
+    # Initialize document converter
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
+        }
+    )
+
+    # Convert test document
+    _log.info(f"Converting {DOC_SOURCE}")
+    conversion_result = doc_converter.convert(source=DOC_SOURCE)
+
+    # Basic conversion checks
+    assert conversion_result.status == ConversionStatus.SUCCESS
+    doc = conversion_result.document
+    assert doc is not None
+
+    # Verify pictures were processed
+    assert len(doc.pictures) > 0
+
+    # Check each picture for descriptions
+    for picture in doc.pictures:
+        # Not every picture has annotations: pictures that are too small (below the
+        # area threshold param, 5% of the page area by default) are ignored by the conversion pipeline.
+        if len(picture.annotations) > 0:
+            # Get the description
+            descriptions = [
+                ann
+                for ann in picture.annotations
+                if isinstance(ann, PictureDescriptionData)
+            ]
+            assert len(descriptions) > 0
+
+            # Verify each description is non-empty
+            for desc in descriptions:
+                assert isinstance(desc.text, str)
+                assert len(desc.text) > 0
+                _log.info(
+                    f"\nPicture ref: {picture.get_ref().cref}, page #{picture.prov[0].page_no}"
+                )
+                _log.info(f"\tGenerated description: {desc.text}")
+        else:
+            _log.info(
+                f"Picture {picture.get_ref().cref} has no annotations (too small?)"
+            )
+
+
+@pytest.mark.skipif(
+    not is_api_available(REMOTE_CHAT_API_URL),
+    reason="Remote API endpoint is not accessible",
+)
+def test_picture_description_context_api_integration():
+    """Test that the context window functionality works correctly in the picture description pipeline using a VLM served via an API."""
+    # Setup pipeline options with context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        do_picture_description=True,
+        generate_picture_images=True,
+        enable_remote_services=True,
+        picture_description_options=PictureDescriptionApiOptions(
+            url=REMOTE_CHAT_API_URL,
+            params=dict(model=API_VISION_MODEL),
+            text_context_window_size_before_picture=2,  # Get 2 text items before
+            text_context_window_size_after_picture=1,  # Get 1 text item after
+            prompt=PROMPT,
+            timeout=90,
+        ),
+    )
+
+    process_document(pipeline_options)
+
+
+def test_picture_description_context_vlm_integration():
+    """Test that the context window functionality works correctly in the picture description pipeline."""
+    # Setup pipeline options with context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        generate_page_images=True,
+        do_picture_description=True,
+        generate_picture_images=True,
+        picture_description_options=PictureDescriptionVlmOptions(
+            repo_id=LOCAL_VISION_MODEL,
+            text_context_window_size_before_picture=2,  # Get 2 text items before
+            text_context_window_size_after_picture=1,  # Get 1 text item after
+            prompt=PROMPT,
+        ),
+    )
+
+    process_document(pipeline_options)
+
+
+def test_picture_description_no_context_vlm_integration():
+    """Test that picture description works without context windows."""
+    # Setup pipeline options without context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        do_picture_description=True,
+        generate_picture_images=True,
+        picture_description_options=PictureDescriptionVlmOptions(
+            repo_id=LOCAL_VISION_MODEL,
+            text_context_window_size_before_picture=0,  # No text context
+            text_context_window_size_after_picture=0,  # No text context
+            prompt=PROMPT,
+        ),
+    )
+
+    process_document(pipeline_options)
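
For quick reference, a minimal usage sketch (not part of the patch) of the new context-window options, simply mirroring the configuration exercised in the tests above; the endpoint URL and model name are placeholders.

import requests  # only to illustrate that the endpoint must be reachable

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionApiOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

pipeline_options = PdfPipelineOptions(
    do_picture_description=True,
    generate_picture_images=True,
    enable_remote_services=True,
    picture_description_options=PictureDescriptionApiOptions(
        url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
        params=dict(model="my-vision-model"),  # placeholder model name
        text_context_window_size_before_picture=2,  # 2 text items before the picture
        text_context_window_size_after_picture=1,  # 1 text item after the picture
    ),
)
converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)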