feat: Picture description using surrounding text as context

- Add the ability to use the text items surrounding a picture as context when prompting the VLM (see the usage sketch below).
- Implement the context-aware prompting in both the API-based and the local VLM picture description models.
- Make the context window configurable via the number of text items taken before and after the picture.
- Add tests for both the context and no-context approaches.
- Include formatting fixes.

Signed-off-by: Rafael T. C. Soares <rafaelcba@gmail.com>

Rafael T. C. Soares, 2025-04-24 18:27:16 -05:00
parent 23238c241f, commit d7922ab31d
6 changed files with 322 additions and 13 deletions
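
For orientation, a minimal usage sketch of the new options (distilled from the tests added in this commit; the model id is one of the test values and the input path is a placeholder):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import (
        PdfPipelineOptions,
        PictureDescriptionVlmOptions,
    )
    from docling.document_converter import DocumentConverter, PdfFormatOption

    # Use the two text items before and the one after each picture as prompt context.
    pipeline_options = PdfPipelineOptions(
        do_picture_description=True,
        generate_picture_images=True,
        picture_description_options=PictureDescriptionVlmOptions(
            repo_id="ibm-granite/granite-vision-3.2-2b",  # example model from the tests
            prompt="Describe the image using the text above as additional context.",
            text_context_window_size_before_picture=2,
            text_context_window_size_after_picture=1,
        ),
    )

    converter = DocumentConverter(
        format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
    )
    result = converter.convert("document.pdf")  # placeholder input

Note that with the new defaults (one text item before the picture, none after), the context-aware path is active out of the box.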

docling/datamodel/pipeline_options.py

@@ -212,6 +212,12 @@ class OcrMacOptions(OcrOptions):
 class PictureDescriptionBaseOptions(BaseOptions):
     batch_size: int = 8
     scale: float = 2
+    text_context_window_size_before_picture: int = (
+        1  # Number of text items to consider before the image
+    )
+    text_context_window_size_after_picture: int = (
+        0  # Number of text items to consider after the image
+    )
     picture_area_threshold: float = (
         0.05  # percentage of the area for a picture to be processed with the models
     )

docling/models/picture_description_api_model.py

@@ -1,6 +1,7 @@
+import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union

 from PIL import Image
@@ -13,6 +14,8 @@ from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.api_image_request import api_image_request

+_log = logging.getLogger(__name__)
+

 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
     # elements_batch_size = 4
@@ -57,3 +60,22 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 headers=self.options.headers,
                 **self.options.params,
             )
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        # Note: technically we could make a batch request here,
+        # but not all APIs will allow for it. For example, vllm won't allow more than 1.
+        for image, context in image_context_map:
+            # Create context-aware prompt
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+            yield api_image_request(
+                image=image,
+                prompt=context_prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
+                headers=self.options.headers,
+                **self.options.params,
+            )

docling/models/picture_description_base_model.py

@@ -1,7 +1,8 @@
+import logging
 from abc import abstractmethod
 from collections.abc import Iterable
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union

 from docling_core.types.doc import (
     DoclingDocument,
@@ -23,6 +24,8 @@ from docling.models.base_model import (
     ItemAndImageEnrichmentElement,
 )

+_log = logging.getLogger(__name__)
+

 class PictureDescriptionBaseModel(
     BaseItemAndImageEnrichmentModel, BaseModelWithOptions
@@ -48,6 +51,70 @@ class PictureDescriptionBaseModel(
     def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         raise NotImplementedError

+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        """Override this method to support context in concrete implementations."""
+        # Extract only the images from the (image, context) pairs
+        images = [image_context_pair[0] for image_context_pair in image_context_map]
+        # Default implementation ignores context
+        yield from self._annotate_images(images)
+
+    def _get_surrounding_text(
+        self, doc: DoclingDocument, picture_item: PictureItem
+    ) -> str:
+        """Get text context from items before and after the picture."""
+        context = []
+        text_items_before_picture, text_items_after_picture = [], []
+        found_picture = False
+        after_count = 0
+        _log.debug(
+            "Getting surrounding text for picture ref: %s", picture_item.self_ref
+        )
+        for item, _ in doc.iterate_items():
+            if item == picture_item:
+                found_picture = True
+                continue
+
+            if not found_picture:  # before picture
+                if isinstance(item, (str, NodeItem)):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        # hold all text items before the picture
+                        text_items_before_picture.append(text)
+            else:  # after picture
+                if (
+                    isinstance(item, (str, NodeItem))
+                    and after_count
+                    < self.options.text_context_window_size_after_picture
+                ):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        text_items_after_picture.append(text)
+                        after_count += 1
+                if after_count >= self.options.text_context_window_size_after_picture:
+                    # Stop once we have reached the limit of text items after the picture
+                    break
+
+        # Combine text items before and after the picture
+        if self.options.text_context_window_size_before_picture > 0:
+            # keep only the last N text items before the picture
+            context.extend(
+                text_items_before_picture[
+                    -self.options.text_context_window_size_before_picture :
+                ]
+            )
+        if self.options.text_context_window_size_after_picture > 0:
+            context.extend(text_items_after_picture)
+        _log.debug("Context before picture: %s", text_items_before_picture)
+        _log.debug("Context after picture: %s", text_items_after_picture)
+        # Join the context with newlines
+        return "\n".join(context)
+
     def __call__(
         self,
         doc: DoclingDocument,
@@ -58,8 +125,9 @@ class PictureDescriptionBaseModel(
                 yield element.item
             return

-        images: List[Image.Image] = []
-        elements: List[PictureItem] = []
+        image_context_map: List[Tuple[Image.Image, str]] = []
+        pictures: List[PictureItem] = []
+
         for el in element_batch:
             assert isinstance(el.item, PictureItem)
             describe_image = True
@@ -74,16 +142,31 @@ class PictureDescriptionBaseModel(
                 if area_fraction < self.options.picture_area_threshold:
                     describe_image = False

             if describe_image:
-                elements.append(el.item)
-                images.append(el.image)
+                pictures.append(el.item)
+                context = ""
+                if (
+                    self.options.text_context_window_size_before_picture > 0
+                    or self.options.text_context_window_size_after_picture > 0
+                ):
+                    # Get the surrounding text context
+                    context = self._get_surrounding_text(doc, el.item)
+                image_context_map.append((el.image, context))

-        outputs = self._annotate_images(images)
-        for item, output in zip(elements, outputs):
-            item.annotations.append(
-                PictureDescriptionData(text=output, provenance=self.provenance)
-            )
-            yield item
+        if (
+            self.options.text_context_window_size_before_picture > 0
+            or self.options.text_context_window_size_after_picture > 0
+        ):
+            picture_descriptions = self._annotate_with_context(image_context_map)
+        else:
+            picture_descriptions = self._annotate_images(
+                image for image, _ in image_context_map
+            )
+
+        for picture, description in zip(pictures, picture_descriptions):
+            picture.annotations.append(
+                PictureDescriptionData(text=description, provenance=self.provenance)
+            )
+            yield picture

     @classmethod
     @abstractmethod
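
To make the window semantics concrete, a tiny illustrative sketch (hypothetical items and window sizes, not part of the diff):

    # Hypothetical illustration of the windowing in _get_surrounding_text:
    before = ["A", "B", "C"]  # non-empty text items found before the picture
    after = ["D", "E"]        # non-empty text items found after the picture
    n_before, n_after = 2, 1  # the text_context_window_size_* option values

    context = "\n".join(before[-n_before:] + after[:n_after])
    assert context == "B\nC\nD"
    # The method guards the n_before == 0 case explicitly, since before[-0:]
    # would return the whole list rather than an empty one.
    # _annotate_with_context then builds the prompt as f"{context}\n{options.prompt}".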

docling/models/picture_description_vlm_model.py

@@ -1,6 +1,7 @@
+import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union

 from PIL import Image
@@ -12,6 +13,8 @@ from docling.datamodel.pipeline_options import (
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.accelerator_utils import decide_device

+_log = logging.getLogger(__name__)
+

 class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
     @classmethod
@@ -121,3 +124,42 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             )

             yield generated_texts[0].strip()
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        from transformers import GenerationConfig
+
+        for image, context in image_context_map:
+            # Create input messages with context
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": context_prompt},
+                    ],
+                },
+            ]
+
+            # Prepare inputs
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            # Generate outputs
+            generated_ids = self.model.generate(
+                **inputs,
+                generation_config=GenerationConfig(**self.options.generation_config),
+            )
+            generated_texts = self.processor.batch_decode(
+                generated_ids[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )
+
+            yield generated_texts[0].strip()

pyproject.toml

@@ -283,3 +283,9 @@ branch = "main"
 parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
 parser_angular_minor_types = "feat"
 parser_angular_patch_types = "fix,perf"
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
+log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"


@@ -0,0 +1,150 @@
+import logging
+from pathlib import Path
+
+import pytest
+import requests
+from docling_core.types.doc.document import PictureDescriptionData
+
+from docling.datamodel.base_models import ConversionStatus, InputFormat
+from docling.datamodel.pipeline_options import (
+    PdfPipelineOptions,
+    PictureDescriptionApiOptions,
+    PictureDescriptionVlmOptions,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+
+# Configure logging at the top of the file
+logging.basicConfig(level=logging.DEBUG)
+_log = logging.getLogger(__name__)
+
+IMAGE_RESOLUTION_SCALE = 2.0
+LOCAL_VISION_MODEL = "ibm-granite/granite-vision-3.2-2b"
+# LOCAL_VISION_MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"
+API_VISION_MODEL = "granite3.2-vision:2b"
+REMOTE_CHAT_API_URL = "http://localhost:8321/v1/openai/v1/chat/completions"  # for the llama-stack OpenAI API interface
+DOC_SOURCE = "https://www.allspringglobal.com/globalassets/assets/regulatory/summary-prospectus/emerging-markets-equity-summ.pdf"
+PROMPT = (
+    "Please describe the image using the text above as additional context. "
+    "Additionally, if the image contains a chart (like a bar chart, pie chart, line chart, etc.), "
+    "please try to extract a list of data points (percentages, numbers, etc.) that are depicted in the chart. "
+    "Also, based on the type of information extracted, "
+    "when applicable try to summarize it using bullet points or even a tabular representation using markdown if possible."
+)
+
+
+def is_api_available(url: str, timeout: int = 3) -> bool:
+    try:
+        requests.get(url, timeout=timeout)
+        return True
+    except (requests.ConnectionError, requests.Timeout) as e:
+        _log.debug(f"API endpoint {url} is not reachable: {e!s}")
+        return False
+
+
+def process_document(pipeline_options: PdfPipelineOptions):
+    # Initialize document converter
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
+        }
+    )
+
+    # Convert test document
+    _log.info(f"Converting {DOC_SOURCE}")
+    conversion_result = doc_converter.convert(source=DOC_SOURCE)
+
+    # Basic conversion checks
+    assert conversion_result.status == ConversionStatus.SUCCESS
+    doc = conversion_result.document
+    assert doc is not None
+
+    # Verify pictures were processed
+    assert len(doc.pictures) > 0
+
+    # Check each picture for descriptions
+    for picture in doc.pictures:
+        # Not every picture has annotations: pictures below the area threshold
+        # (5% of the page area by default) are skipped by the conversion pipeline.
+        if len(picture.annotations) > 0:
+            # Get the descriptions
+            descriptions = [
+                ann
+                for ann in picture.annotations
+                if isinstance(ann, PictureDescriptionData)
+            ]
+            assert len(descriptions) > 0
+
+            # Verify each description is non-empty
+            for desc in descriptions:
+                assert isinstance(desc.text, str)
+                assert len(desc.text) > 0
+                _log.info(
+                    f"\nPicture ref: {picture.get_ref().cref}, page #{picture.prov[0].page_no}"
+                )
+                _log.info(f"\tGenerated description: {desc.text}")
+        else:
+            _log.info(
+                f"Picture {picture.get_ref().cref} has no annotations (too small?)"
+            )
+
+
+@pytest.mark.skipif(
+    not is_api_available(REMOTE_CHAT_API_URL),
+    reason="Remote API endpoint is not accessible",
+)
+def test_picture_description_context_api_integration():
+    """Test that the context-window functionality works in the picture description pipeline using a VLM served via an API."""
+    # Set up pipeline options with context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        do_picture_description=True,
+        generate_picture_images=True,
+        enable_remote_services=True,
+        picture_description_options=PictureDescriptionApiOptions(
+            url=REMOTE_CHAT_API_URL,
+            params=dict(model=API_VISION_MODEL),
+            text_context_window_size_before_picture=2,  # Get 2 text items before
+            text_context_window_size_after_picture=1,  # Get 1 text item after
+            prompt=PROMPT,
+            timeout=90,
+        ),
+    )
+
+    process_document(pipeline_options)
+
+
+def test_picture_description_context_vlm_integration():
+    """Test that the context-window functionality works in the picture description pipeline using a local VLM."""
+    # Set up pipeline options with context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        generate_page_images=True,
+        do_picture_description=True,
+        generate_picture_images=True,
+        picture_description_options=PictureDescriptionVlmOptions(
+            repo_id=LOCAL_VISION_MODEL,
+            text_context_window_size_before_picture=2,  # Get 2 text items before
+            text_context_window_size_after_picture=1,  # Get 1 text item after
+            prompt=PROMPT,
+        ),
+    )
+
+    process_document(pipeline_options)
+
+
+def test_picture_description_no_context_vlm_integration():
+    """Test that picture description works without context windows."""
+    # Set up pipeline options without context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        do_picture_description=True,
+        generate_picture_images=True,
+        picture_description_options=PictureDescriptionVlmOptions(
+            repo_id=LOCAL_VISION_MODEL,
+            text_context_window_size_before_picture=0,  # No text context
+            text_context_window_size_after_picture=0,  # No text context
+            prompt=PROMPT,
+        ),
+    )
+
+    process_document(pipeline_options)
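
A note on running the tests: with the [tool.pytest.ini_options] section added to pyproject.toml above, pytest streams INFO-level logs live, so the generated descriptions are visible while the integration tests run. The API-backed test skips itself when nothing is listening at REMOTE_CHAT_API_URL, and the two local-VLM tests download the LOCAL_VISION_MODEL checkpoint on first run; a selector such as pytest -k picture_description (a hypothetical invocation, since the test module's file name is not shown in this diff) picks up these tests.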