Rafael Torres Coelho Soares (aka Tuelho) 2025-05-14 13:06:24 +00:00 committed by GitHub
commit 7bc9e6f963
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 322 additions and 13 deletions

View File

@@ -212,6 +212,12 @@ class OcrMacOptions(OcrOptions):
class PictureDescriptionBaseOptions(BaseOptions):
batch_size: int = 8
scale: float = 2
text_context_window_size_before_picture: int = (
1 # Number of text items to consider before the image
)
text_context_window_size_after_picture: int = (
0 # Number of text items to consider after the image
)
picture_area_threshold: float = (
0.05 # percentage of the area for a picture to be processed with the models
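The two new window options plug in alongside the existing picture-description settings. A minimal configuration sketch, mirroring the options exercised in the new integration test (endpoint URL and model name are illustrative placeholders):

from docling.datamodel.pipeline_options import PictureDescriptionApiOptions

options = PictureDescriptionApiOptions(
    url="http://localhost:8000/v1/chat/completions",  # hypothetical endpoint
    params=dict(model="my-vision-model"),  # hypothetical model name
    text_context_window_size_before_picture=2,  # last 2 text items before the image
    text_context_window_size_after_picture=1,  # first text item after the image
    timeout=90,
)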

View File

@@ -1,6 +1,7 @@
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Optional, Tuple, Type, Union
from PIL import Image
@@ -13,6 +14,8 @@ from docling.exceptions import OperationNotAllowed
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.utils.api_image_request import api_image_request
_log = logging.getLogger(__name__)
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
# elements_batch_size = 4
@@ -57,3 +60,22 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
headers=self.options.headers,
**self.options.params,
)
def _annotate_with_context(
self, image_context_map: Iterable[Tuple[Image.Image, str]]
) -> Iterable[str]:
# Note: technically we could make a batch request here,
# but not all APIs will allow for it. For example, vllm won't allow more than 1.
for image, context in image_context_map:
# Create context-aware prompt
context_prompt = f"{context}\n{self.options.prompt}"
_log.debug("Prompt: %s", context_prompt)
yield api_image_request(
image=image,
prompt=context_prompt,
url=self.options.url,
timeout=self.options.timeout,
headers=self.options.headers,
**self.options.params,
)
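For illustration, a sketch of how this method consumes its input, given an already-configured PictureDescriptionApiModel instance named model (hypothetical in-memory images; one request is issued per pair, since batch support is not assumed for every API):

from PIL import Image

pairs = [
    (Image.new("RGB", (64, 64)), "Paragraph that precedes the figure."),
    (Image.new("RGB", (64, 64)), ""),  # empty context: the base prompt is used essentially as-is
]
descriptions = list(model._annotate_with_context(pairs))  # one description per image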

View File

@@ -1,7 +1,8 @@
import logging
from abc import abstractmethod
from collections.abc import Iterable
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union
from docling_core.types.doc import (
DoclingDocument,
@@ -23,6 +24,8 @@ from docling.models.base_model import (
ItemAndImageEnrichmentElement,
)
_log = logging.getLogger(__name__)
class PictureDescriptionBaseModel(
BaseItemAndImageEnrichmentModel, BaseModelWithOptions
@@ -48,6 +51,70 @@ class PictureDescriptionBaseModel(
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
raise NotImplementedError
def _annotate_with_context(
self, image_context_map: Iterable[Tuple[Image.Image, str]]
) -> Iterable[str]:
"""Override this method to support context in concrete implementations."""
# Extract only the images from the (image, context) pairs
images = [image_context_pair[0] for image_context_pair in image_context_map]
# Default implementation ignores context
yield from self._annotate_images(images)
def _get_surrounding_text(
self, doc: DoclingDocument, picture_item: PictureItem
) -> str:
"""Get text context from items before and after the picture."""
context = []
text_items_before_picture, text_items_after_picture = [], []
found_picture = False
after_count = 0
_log.debug(
"Getting surrounding text for picture ref: %s", picture_item.self_ref
)
for item, _ in doc.iterate_items():
if item == picture_item:
found_picture = True
continue
if not found_picture: # before picture
if isinstance(item, (str, NodeItem)):
text = item if isinstance(item, str) else getattr(item, "text", "")
if text and text.strip():
# hold all text items before the picture
text_items_before_picture.append(text)
else: # after picture
if (
isinstance(item, (str, NodeItem))
and after_count
< self.options.text_context_window_size_after_picture
):
text = item if isinstance(item, str) else getattr(item, "text", "")
if text and text.strip():
text_items_after_picture.append(text)
after_count += 1
if after_count >= self.options.text_context_window_size_after_picture:
# Stop if we have reached the limit of text items after the picture
break
# Combine text items before and after the picture
if self.options.text_context_window_size_before_picture > 0:
# get only the last N text items before the picture
context.extend(
text_items_before_picture[
-self.options.text_context_window_size_before_picture :
]
)
if self.options.text_context_window_size_after_picture > 0:
context.extend(text_items_after_picture)
_log.debug("Context before picture: %s", text_items_before_picture)
_log.debug("Context after picture: %s", text_items_after_picture)
# Join the context with newlines
return "\n".join(context)
def __call__(
self,
doc: DoclingDocument,
@@ -58,8 +125,9 @@ class PictureDescriptionBaseModel(
yield element.item
return
image_context_map: List[Tuple[Image.Image, str]] = []
pictures: List[PictureItem] = []
for el in element_batch:
assert isinstance(el.item, PictureItem)
describe_image = True
@@ -74,16 +142,31 @@ class PictureDescriptionBaseModel(
if area_fraction < self.options.picture_area_threshold:
describe_image = False
if describe_image:
pictures.append(el.item)
context = ""
if (
self.options.text_context_window_size_before_picture > 0
or self.options.text_context_window_size_after_picture > 0
):
# Get the surrounding text context
context = self._get_surrounding_text(doc, el.item)
image_context_map.append((el.image, context))
if (
self.options.text_context_window_size_before_picture > 0
or self.options.text_context_window_size_after_picture > 0
):
picture_descriptions = self._annotate_with_context(image_context_map)
else:
picture_descriptions = self._annotate_images(
image for image, _ in image_context_map
)
for picture, description in zip(pictures, picture_descriptions):
picture.annotations.append(
PictureDescriptionData(text=description, provenance=self.provenance)
)
yield picture
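Note that when both window sizes are zero, the batch is routed through the original _annotate_images path, so existing subclasses that never override _annotate_with_context keep their previous behavior unchanged.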
@classmethod
@abstractmethod

View File

@@ -1,6 +1,7 @@
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Optional, Tuple, Type, Union
from PIL import Image
@@ -12,6 +13,8 @@ from docling.datamodel.pipeline_options import (
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.utils.accelerator_utils import decide_device
_log = logging.getLogger(__name__)
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
@classmethod
@@ -121,3 +124,42 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
)
yield generated_texts[0].strip()
def _annotate_with_context(
self, image_context_map: Iterable[Tuple[Image.Image, str]]
) -> Iterable[str]:
from transformers import GenerationConfig
for image, context in image_context_map:
# Create input messages with context
context_prompt = f"{context}\n{self.options.prompt}"
_log.debug("Prompt: %s", context_prompt)
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": context_prompt},
],
},
]
# Prepare inputs
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=True
)
inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
inputs = inputs.to(self.device)
# Generate outputs
generated_ids = self.model.generate(
**inputs,
generation_config=GenerationConfig(**self.options.generation_config),
)
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=True,
)
yield generated_texts[0].strip()
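Decoding is driven by options.generation_config, which is unpacked into a transformers GenerationConfig above. A hypothetical configuration combining it with the new context windows (field values are illustrative):

from docling.datamodel.pipeline_options import PictureDescriptionVlmOptions

options = PictureDescriptionVlmOptions(
    repo_id="ibm-granite/granite-vision-3.2-2b",
    generation_config=dict(max_new_tokens=200, do_sample=False),
    text_context_window_size_before_picture=2,
    text_context_window_size_after_picture=1,
)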

View File

@@ -283,3 +283,9 @@ branch = "main"
parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
parser_angular_minor_types = "feat"
parser_angular_patch_types = "fix,perf"
[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
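With log_cli enabled, pytest streams log records to the terminal while the tests run, using the level and format configured above. For example, to run only the new tests with more verbose live logging (illustrative invocation):

pytest -k picture_description --log-cli-level=DEBUG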

View File

@@ -0,0 +1,150 @@
import logging
from pathlib import Path
import pytest
import requests
from docling_core.types.doc.document import PictureDescriptionData
from docling.datamodel.base_models import ConversionStatus, InputFormat
from docling.datamodel.pipeline_options import (
PdfPipelineOptions,
PictureDescriptionApiOptions,
PictureDescriptionVlmOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
# Configure logging at the top of the file
logging.basicConfig(level=logging.DEBUG)
_log = logging.getLogger(__name__)
IMAGE_RESOLUTION_SCALE = 2.0
LOCAL_VISION_MODEL = "ibm-granite/granite-vision-3.2-2b"
# LOCAL_VISION_MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"
API_VISION_MODEL = "granite3.2-vision:2b"
REMOTE_CHAT_API_URL = "http://localhost:8321/v1/openai/v1/chat/completions" # for llama-stack OpenAI API interface
DOC_SOURCE = "https://www.allspringglobal.com/globalassets/assets/regulatory/summary-prospectus/emerging-markets-equity-summ.pdf"
PROMPT = (
"Please describe the image using the text above as additional context. "
"Additionally, if the image contains a chart (like a bar chart, pie chart, line chart, etc.), "
"please try to extract a list of data points (percentages, numbers, etc.) that are depicted in the chart. "
"Also, based on the type of information extracted, "
"when applicable try to summarize it using bullet points or even a tabular representation using markdown if possible."
)
def is_api_available(url: str, timeout: int = 3) -> bool:
try:
requests.get(url, timeout=timeout)
return True
except (requests.ConnectionError, requests.Timeout) as e:
_log.debug(f"API endpoint {url} is not reachable: {e!s}")
return False
def process_document(pipeline_options: PdfPipelineOptions):
# Initialize document converter
doc_converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
}
)
# Convert test document
_log.info(f"Converting {DOC_SOURCE} with VLM API")
conversion_result = doc_converter.convert(source=DOC_SOURCE)
# Basic conversion checks
assert conversion_result.status == ConversionStatus.SUCCESS
doc = conversion_result.document
assert doc is not None
# Verify pictures were processed
assert len(doc.pictures) > 0
# Check each picture for descriptions
for picture in doc.pictures:
# Not every picture has annotations: pictures below the area threshold
# (5% of the page area by default) are skipped by the conversion pipeline
if len(picture.annotations) > 0:
# Get the description
descriptions = [
ann
for ann in picture.annotations
if isinstance(ann, PictureDescriptionData)
]
assert len(descriptions) > 0
# Verify each description is non-empty
for desc in descriptions:
assert isinstance(desc.text, str)
assert len(desc.text) > 0
_log.info(
f"\nPicture ref: {picture.get_ref().cref}, page #{picture.prov[0].page_no}"
)
_log.info(f"\tGenerated description: {desc.text}")
else:
_log.info(
f"Picture {picture.get_ref().cref} has no annotations (too small?)"
)
@pytest.mark.skipif(
not is_api_available(REMOTE_CHAT_API_URL),
reason="Remote API endpoint is not accessible",
)
def test_picture_description_context_api_integration():
"""Test that the context windows functionality works correctly in the picture description pipeline using a VLM served via API"""
# Setup pipeline options with context windows
pipeline_options = PdfPipelineOptions(
images_scale=IMAGE_RESOLUTION_SCALE,
do_picture_description=True,
generate_picture_images=True,
enable_remote_services=True,
picture_description_options=PictureDescriptionApiOptions(
url=REMOTE_CHAT_API_URL,
params=dict(model=API_VISION_MODEL),
text_context_window_size_before_picture=2, # Get 2 text items before
text_context_window_size_after_picture=1, # Get 1 text item after
prompt=PROMPT,
timeout=90,
),
)
process_document(pipeline_options)
def test_picture_description_context_vlm_integration():
"""Test that the context windows functionality works correctly in the picture description pipeline"""
# Setup pipeline options with context windows
pipeline_options = PdfPipelineOptions(
images_scale=IMAGE_RESOLUTION_SCALE,
generate_page_images=True,
do_picture_description=True,
generate_picture_images=True,
picture_description_options=PictureDescriptionVlmOptions(
repo_id=LOCAL_VISION_MODEL,
text_context_window_size_before_picture=2, # Get 2 text items before
text_context_window_size_after_picture=1, # Get 1 text item after
prompt=PROMPT,
),
)
process_document(pipeline_options)
def test_picture_description_no_context_vlm_integration():
"""Test that the picture description works without context windows"""
# Setup pipeline options without context windows
pipeline_options = PdfPipelineOptions(
images_scale=IMAGE_RESOLUTION_SCALE,
do_picture_description=True,
generate_picture_images=True,
picture_description_options=PictureDescriptionVlmOptions(
repo_id=LOCAL_VISION_MODEL,
text_context_window_size_before_picture=0, # No text context
text_context_window_size_after_picture=0, # No text context
prompt=PROMPT,
),
)
process_document(pipeline_options)