feat: Picture description using context with surrounding text
- Add the ability to use text items surrounding the picture as context to prompt the VLM.
- Implemented VLM-based picture description functionality.
- Added the ability to use text before and after pictures as context.
- Added tests for both the context and non-context approaches.
- Included formatting fixes.

Signed-off-by: Rafael T. C. Soares <rafaelcba@gmail.com>
This commit is contained in: parent 23238c241f, commit d7922ab31d
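Usage sketch: the feature is controlled by the two new options on PictureDescriptionBaseOptions and is exercised by the tests added in this commit. A minimal example assuming a local VLM (the model repo id and prompt below are illustrative, taken from the new tests, not defaults):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionVlmOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

# Describe each picture with a local VLM, prompting with the two text
# items before and the one text item after the picture as extra context.
pipeline_options = PdfPipelineOptions(
    do_picture_description=True,
    generate_picture_images=True,
    picture_description_options=PictureDescriptionVlmOptions(
        repo_id="ibm-granite/granite-vision-3.2-2b",
        text_context_window_size_before_picture=2,
        text_context_window_size_after_picture=1,
        prompt="Describe the image using the text above as additional context.",
    ),
)
converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("document.pdf")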
docling/datamodel/pipeline_options.py
@@ -212,6 +212,12 @@ class OcrMacOptions(OcrOptions):
 class PictureDescriptionBaseOptions(BaseOptions):
     batch_size: int = 8
     scale: float = 2
+    text_context_window_size_before_picture: int = (
+        1  # Number of text items to consider before the image
+    )
+    text_context_window_size_after_picture: int = (
+        0  # Number of text items to consider after the image
+    )
 
     picture_area_threshold: float = (
         0.05  # percentage of the area for a picture to be processed with the models
docling/models/picture_description_api_model.py
@@ -1,6 +1,7 @@
+import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
 
 from PIL import Image
 
@@ -13,6 +14,8 @@ from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.api_image_request import api_image_request
 
+_log = logging.getLogger(__name__)
+
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
     # elements_batch_size = 4
@@ -57,3 +60,22 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 headers=self.options.headers,
                 **self.options.params,
             )
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        # Note: technically we could make a batch request here,
+        # but not all APIs will allow for it. For example, vllm won't allow more than 1.
+        for image, context in image_context_map:
+            # Create context-aware prompt
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+
+            yield api_image_request(
+                image=image,
+                prompt=context_prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
+                headers=self.options.headers,
+                **self.options.params,
+            )
docling/models/picture_description_base_model.py
@@ -1,7 +1,8 @@
+import logging
 from abc import abstractmethod
 from collections.abc import Iterable
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union
 
 from docling_core.types.doc import (
     DoclingDocument,
@@ -23,6 +24,8 @@ from docling.models.base_model import (
     ItemAndImageEnrichmentElement,
 )
 
+_log = logging.getLogger(__name__)
+
 
 class PictureDescriptionBaseModel(
     BaseItemAndImageEnrichmentModel, BaseModelWithOptions
@@ -48,6 +51,70 @@ class PictureDescriptionBaseModel(
     def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         raise NotImplementedError
 
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        """Override this method to support context in concrete implementations."""
+        # Extract only the images from the (image, context) pairs
+        images = [image_context_pair[0] for image_context_pair in image_context_map]
+        # Default implementation ignores context
+        yield from self._annotate_images(images)
+
+    def _get_surrounding_text(
+        self, doc: DoclingDocument, picture_item: PictureItem
+    ) -> str:
+        """Get text context from items before and after the picture."""
+        context = []
+        text_items_before_picture, text_items_after_picture = [], []
+        found_picture = False
+        after_count = 0
+
+        _log.debug(
+            "Getting surrounding text for picture ref: %s", picture_item.self_ref
+        )
+        for item, _ in doc.iterate_items():
+            if item == picture_item:
+                found_picture = True
+                continue
+
+            if not found_picture:  # before picture
+                if isinstance(item, (str, NodeItem)):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        # hold all text items before the picture
+                        text_items_before_picture.append(text)
+            else:  # after picture
+                if (
+                    isinstance(item, (str, NodeItem))
+                    and after_count
+                    < self.options.text_context_window_size_after_picture
+                ):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        text_items_after_picture.append(text)
+                        after_count += 1
+
+                if after_count >= self.options.text_context_window_size_after_picture:
+                    # Stop if we have reached the limit of text items after the picture
+                    break
+
+        # Combine text items before and after the picture
+        if self.options.text_context_window_size_before_picture > 0:
+            # get only the last N text items before the picture
+            context.extend(
+                text_items_before_picture[
+                    -self.options.text_context_window_size_before_picture :
+                ]
+            )
+
+        if self.options.text_context_window_size_after_picture > 0:
+            context.extend(text_items_after_picture)
+
+        _log.debug("Context before picture: %s", text_items_before_picture)
+        _log.debug("Context after picture: %s", text_items_after_picture)
+        # Join the context with newlines
+        return "\n".join(context)
+
     def __call__(
         self,
         doc: DoclingDocument,
@@ -58,8 +125,9 @@ class PictureDescriptionBaseModel(
             yield element.item
             return
 
-        images: List[Image.Image] = []
-        elements: List[PictureItem] = []
+        image_context_map: List[Tuple[Image.Image, str]] = []
+        pictures: List[PictureItem] = []
+
         for el in element_batch:
             assert isinstance(el.item, PictureItem)
             describe_image = True
@@ -74,16 +142,31 @@ class PictureDescriptionBaseModel(
             if area_fraction < self.options.picture_area_threshold:
                 describe_image = False
             if describe_image:
-                elements.append(el.item)
-                images.append(el.image)
+                pictures.append(el.item)
+                context = ""
+                if (
+                    self.options.text_context_window_size_before_picture > 0
+                    or self.options.text_context_window_size_after_picture > 0
+                ):
+                    # Get the surrounding text context
+                    context = self._get_surrounding_text(doc, el.item)
+                image_context_map.append((el.image, context))
 
-        outputs = self._annotate_images(images)
-
-        for item, output in zip(elements, outputs):
-            item.annotations.append(
-                PictureDescriptionData(text=output, provenance=self.provenance)
-            )
-            yield item
+        if (
+            self.options.text_context_window_size_before_picture > 0
+            or self.options.text_context_window_size_after_picture > 0
+        ):
+            picture_descriptions = self._annotate_with_context(image_context_map)
+        else:
+            picture_descriptions = self._annotate_images(
+                image for image, _ in image_context_map
+            )
+
+        for picture, description in zip(pictures, picture_descriptions):
+            picture.annotations.append(
+                PictureDescriptionData(text=description, provenance=self.provenance)
+            )
+            yield picture
 
     @classmethod
     @abstractmethod
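For intuition, a minimal standalone sketch of the window selection that _get_surrounding_text performs (plain strings stand in for docling text items; this is not the docling API): every text item seen before the picture is collected and trimmed to the last N afterwards, while iteration stops as soon as M items after the picture have been gathered.

# Toy model of the context-window logic above.
def surrounding_text(items, picture, n_before=1, n_after=0):
    before, after = [], []
    found = False
    for item in items:
        if item is picture:
            found = True
            continue
        if not found:
            before.append(item)  # keep all; trim to the last n_before below
        elif len(after) < n_after:
            after.append(item)
        else:
            break  # reached the after-picture limit
    context = (before[-n_before:] if n_before > 0 else []) + after
    return "\n".join(context)

PIC = object()  # stands in for the PictureItem
items = ["intro", "methods", PIC, "caption", "discussion"]
assert surrounding_text(items, PIC, n_before=2, n_after=1) == "intro\nmethods\ncaption"
assert surrounding_text(items, PIC, n_before=1, n_after=0) == "methods"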
docling/models/picture_description_vlm_model.py
@@ -1,6 +1,7 @@
+import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
 
 from PIL import Image
 
@@ -12,6 +13,8 @@ from docling.datamodel.pipeline_options import (
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.accelerator_utils import decide_device
 
+_log = logging.getLogger(__name__)
+
 
 class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
     @classmethod
@@ -121,3 +124,42 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             )
 
             yield generated_texts[0].strip()
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        from transformers import GenerationConfig
+
+        for image, context in image_context_map:
+            # Create input messages with context
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": context_prompt},
+                    ],
+                },
+            ]
+
+            # Prepare inputs
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            # Generate outputs
+            generated_ids = self.model.generate(
+                **inputs,
+                generation_config=GenerationConfig(**self.options.generation_config),
+            )
+            generated_texts = self.processor.batch_decode(
+                generated_ids[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )
+
+            yield generated_texts[0].strip()
pyproject.toml
@@ -283,3 +283,9 @@ branch = "main"
 parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
 parser_angular_minor_types = "feat"
 parser_angular_patch_types = "fix,perf"
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
+log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
tests/test_picture_description.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+import logging
+from pathlib import Path
+
+import pytest
+import requests
+from docling_core.types.doc.document import PictureDescriptionData
+
+from docling.datamodel.base_models import ConversionStatus, InputFormat
+from docling.datamodel.pipeline_options import (
+    PdfPipelineOptions,
+    PictureDescriptionApiOptions,
+    PictureDescriptionVlmOptions,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+
+# Configure logging at the top of the file
+logging.basicConfig(level=logging.DEBUG)
+_log = logging.getLogger(__name__)
+
+IMAGE_RESOLUTION_SCALE = 2.0
+LOCAL_VISION_MODEL = "ibm-granite/granite-vision-3.2-2b"
+# LOCAL_VISION_MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"
+API_VISION_MODEL = "granite3.2-vision:2b"
+REMOTE_CHAT_API_URL = "http://localhost:8321/v1/openai/v1/chat/completions"  # for the llama-stack OpenAI API interface
+DOC_SOURCE = "https://www.allspringglobal.com/globalassets/assets/regulatory/summary-prospectus/emerging-markets-equity-summ.pdf"
+PROMPT = (
+    "Please describe the image using the text above as additional context. "
+    "Additionally, if the image contains a chart (like a bar chart, pie chart, line chart, etc.), "
+    "please try to extract a list of data points (percentages, numbers, etc.) that are depicted in the chart. "
+    "Also, based on the type of information extracted, "
+    "when applicable try to summarize it using bullet points or even a tabular representation using markdown if possible."
+)
+
+
+def is_api_available(url: str, timeout: int = 3) -> bool:
+    try:
+        requests.get(url, timeout=timeout)
+        return True
+    except (requests.ConnectionError, requests.Timeout) as e:
+        _log.debug(f"API endpoint {url} is not reachable: {e!s}")
+        return False
+
+
+def process_document(pipeline_options: PdfPipelineOptions):
+    # Initialize the document converter
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
+        }
+    )
+
+    # Convert the test document
+    _log.info(f"Converting {DOC_SOURCE} with VLM API")
+    conversion_result = doc_converter.convert(source=DOC_SOURCE)
+
+    # Basic conversion checks
+    assert conversion_result.status == ConversionStatus.SUCCESS
+    doc = conversion_result.document
+    assert doc is not None
+
+    # Verify pictures were processed
+    assert len(doc.pictures) > 0
+
+    # Check each picture for descriptions
+    for picture in doc.pictures:
+        # Not every picture has annotations (e.g. pictures that are too small,
+        # based on the threshold param (5% of the page area by default),
+        # get ignored by the conversion pipeline)
+        if len(picture.annotations) > 0:
+            # Get the descriptions
+            descriptions = [
+                ann
+                for ann in picture.annotations
+                if isinstance(ann, PictureDescriptionData)
+            ]
+            assert len(descriptions) > 0
+
+            # Verify each description is non-empty
+            for desc in descriptions:
+                assert isinstance(desc.text, str)
+                assert len(desc.text) > 0
+                _log.info(
+                    f"\nPicture ref: {picture.get_ref().cref}, page #{picture.prov[0].page_no}"
+                )
+                _log.info(f"\tGenerated description: {desc.text}")
+        else:
+            _log.info(
+                f"Picture {picture.get_ref().cref} has no annotations (too small?)"
+            )
+
+
+@pytest.mark.skipif(
+    not is_api_available(REMOTE_CHAT_API_URL),
+    reason="Remote API endpoint is not accessible",
+)
+def test_picture_description_context_api_integration():
+    """Test that the context window functionality works correctly in the picture description pipeline using a VLM served via an API"""
+    # Set up pipeline options with context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        do_picture_description=True,
+        generate_picture_images=True,
+        enable_remote_services=True,
+        picture_description_options=PictureDescriptionApiOptions(
+            url=REMOTE_CHAT_API_URL,
+            params=dict(model=API_VISION_MODEL),
+            text_context_window_size_before_picture=2,  # Get 2 text items before
+            text_context_window_size_after_picture=1,  # Get 1 text item after
+            prompt=PROMPT,
+            timeout=90,
+        ),
+    )
+
+    process_document(pipeline_options)
+
+
+def test_picture_description_context_vlm_integration():
+    """Test that the context window functionality works correctly in the picture description pipeline"""
+    # Set up pipeline options with context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        generate_page_images=True,
+        do_picture_description=True,
+        generate_picture_images=True,
+        picture_description_options=PictureDescriptionVlmOptions(
+            repo_id=LOCAL_VISION_MODEL,
+            text_context_window_size_before_picture=2,  # Get 2 text items before
+            text_context_window_size_after_picture=1,  # Get 1 text item after
+            prompt=PROMPT,
+        ),
+    )
+
+    process_document(pipeline_options)
+
+
+def test_picture_description_no_context_vlm_integration():
+    """Test that the picture description works without context windows"""
+    # Set up pipeline options without context windows
+    pipeline_options = PdfPipelineOptions(
+        images_scale=IMAGE_RESOLUTION_SCALE,
+        do_picture_description=True,
+        generate_picture_images=True,
+        picture_description_options=PictureDescriptionVlmOptions(
+            repo_id=LOCAL_VISION_MODEL,
+            text_context_window_size_before_picture=0,  # No text context
+            text_context_window_size_after_picture=0,  # No text context
+            prompt=PROMPT,
+        ),
+    )
+
+    process_document(pipeline_options)