mirror of https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00

commit 7bc9e6f963
Merge d7922ab31d into 9f8b479f17
@@ -212,6 +212,12 @@ class OcrMacOptions(OcrOptions):
 class PictureDescriptionBaseOptions(BaseOptions):
     batch_size: int = 8
     scale: float = 2
+    text_context_window_size_before_picture: int = (
+        1  # Number of text items to consider before the image
+    )
+    text_context_window_size_after_picture: int = (
+        0  # Number of text items to consider after the image
+    )

     picture_area_threshold: float = (
         0.05  # fraction of the page area for a picture to be processed with the models
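Taken together, the new options control how much surrounding text is attached to each picture and which pictures are processed at all. A minimal sketch of setting them on a concrete options class, mirroring the tests further below (the values are illustrative, not the defaults shown above):

from docling.datamodel.pipeline_options import PictureDescriptionVlmOptions

options = PictureDescriptionVlmOptions(
    repo_id="ibm-granite/granite-vision-3.2-2b",
    text_context_window_size_before_picture=2,  # last 2 text items before the picture
    text_context_window_size_after_picture=1,  # first text item after the picture
    picture_area_threshold=0.05,  # skip pictures below 5% of the page area
)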
@@ -1,6 +1,7 @@
 import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union

 from PIL import Image

@@ -13,6 +14,8 @@ from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.api_image_request import api_image_request

+_log = logging.getLogger(__name__)
+

 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
     # elements_batch_size = 4
@@ -57,3 +60,22 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 headers=self.options.headers,
                 **self.options.params,
             )
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        # Note: technically we could make a batch request here,
+        # but not all APIs will allow for it. For example, vllm won't allow more than 1.
+        for image, context in image_context_map:
+            # Create context-aware prompt
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+
+            yield api_image_request(
+                image=image,
+                prompt=context_prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
+                headers=self.options.headers,
+                **self.options.params,
+            )
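The context is simply prepended to the configured prompt, one API request per image; there is no batching, as the comment in the code notes. A small illustration of the prompt assembly (the strings are hypothetical):

context = "Figure 3: quarterly revenue by region."
prompt = "Describe the image."
context_prompt = f"{context}\n{prompt}"
# -> "Figure 3: quarterly revenue by region.\nDescribe the image."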
@@ -1,7 +1,8 @@
 import logging
 from abc import abstractmethod
 from collections.abc import Iterable
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union

 from docling_core.types.doc import (
     DoclingDocument,
@@ -23,6 +24,8 @@ from docling.models.base_model import (
     ItemAndImageEnrichmentElement,
 )

+_log = logging.getLogger(__name__)
+

 class PictureDescriptionBaseModel(
     BaseItemAndImageEnrichmentModel, BaseModelWithOptions
@@ -48,6 +51,70 @@ class PictureDescriptionBaseModel(
     def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         raise NotImplementedError

+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        """Override this method to support context in concrete implementations."""
+        # Extract only the images from the (image, context) pairs
+        images = [image_context_pair[0] for image_context_pair in image_context_map]
+        # Default implementation ignores context
+        yield from self._annotate_images(images)
+
+    def _get_surrounding_text(
+        self, doc: DoclingDocument, picture_item: PictureItem
+    ) -> str:
+        """Get text context from items before and after the picture."""
+        context = []
+        text_items_before_picture, text_items_after_picture = [], []
+        found_picture = False
+        after_count = 0
+
+        _log.debug(
+            "Getting surrounding text for picture ref: %s", picture_item.self_ref
+        )
+        for item, _ in doc.iterate_items():
+            if item == picture_item:
+                found_picture = True
+                continue
+
+            if not found_picture:  # before the picture
+                if isinstance(item, (str, NodeItem)):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        # hold all text items seen before the picture
+                        text_items_before_picture.append(text)
+            else:  # after the picture
+                if (
+                    isinstance(item, (str, NodeItem))
+                    and after_count
+                    < self.options.text_context_window_size_after_picture
+                ):
+                    text = item if isinstance(item, str) else getattr(item, "text", "")
+                    if text and text.strip():
+                        text_items_after_picture.append(text)
+                        after_count += 1
+
+                if after_count >= self.options.text_context_window_size_after_picture:
+                    # Stop once we have reached the limit of text items after the picture
+                    break
+
+        # Combine text items before and after the picture
+        if self.options.text_context_window_size_before_picture > 0:
+            # keep only the last N text items before the picture
+            context.extend(
+                text_items_before_picture[
+                    -self.options.text_context_window_size_before_picture :
+                ]
+            )
+
+        if self.options.text_context_window_size_after_picture > 0:
+            context.extend(text_items_after_picture)
+
+        _log.debug("Context before picture: %s", text_items_before_picture)
+        _log.debug("Context after picture: %s", text_items_after_picture)
+        # Join the context items with newlines
+        return "\n".join(context)
+
     def __call__(
         self,
         doc: DoclingDocument,
@@ -58,8 +125,9 @@ class PictureDescriptionBaseModel(
                 yield element.item
             return

-        images: List[Image.Image] = []
-        elements: List[PictureItem] = []
+        image_context_map: List[Tuple[Image.Image, str]] = []
+        pictures: List[PictureItem] = []

         for el in element_batch:
             assert isinstance(el.item, PictureItem)
             describe_image = True
@@ -74,16 +142,31 @@ class PictureDescriptionBaseModel(
             if area_fraction < self.options.picture_area_threshold:
                 describe_image = False
             if describe_image:
-                elements.append(el.item)
-                images.append(el.image)
+                pictures.append(el.item)
+                context = ""
+                if (
+                    self.options.text_context_window_size_before_picture > 0
+                    or self.options.text_context_window_size_after_picture > 0
+                ):
+                    # Get the surrounding text context
+                    context = self._get_surrounding_text(doc, el.item)
+                image_context_map.append((el.image, context))

-        outputs = self._annotate_images(images)
-
-        for item, output in zip(elements, outputs):
-            item.annotations.append(
-                PictureDescriptionData(text=output, provenance=self.provenance)
+        if (
+            self.options.text_context_window_size_before_picture > 0
+            or self.options.text_context_window_size_after_picture > 0
+        ):
+            picture_descriptions = self._annotate_with_context(image_context_map)
+        else:
+            picture_descriptions = self._annotate_images(
+                image for image, _ in image_context_map
             )
-            yield item
+
+        for picture, description in zip(pictures, picture_descriptions):
+            picture.annotations.append(
+                PictureDescriptionData(text=description, provenance=self.provenance)
+            )
+            yield picture

     @classmethod
     @abstractmethod
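In short, _get_surrounding_text keeps the last N non-empty text items before the picture and the first M after it, joined by newlines. The same selection on plain strings, as a standalone sketch (document items are reduced to strings purely for illustration):

items = ["Title", "Intro paragraph", "Figure caption", "<picture>", "Caption follow-up", "Next section"]
before, after = 2, 1  # window sizes, matching the options above

idx = items.index("<picture>")
texts_before = [t for t in items[:idx] if t.strip()]
texts_after = [t for t in items[idx + 1 :] if t.strip()][:after]

context = "\n".join(texts_before[-before:] + texts_after)
# -> "Intro paragraph\nFigure caption\nCaption follow-up"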
@@ -1,6 +1,7 @@
 import logging
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union

 from PIL import Image

@@ -12,6 +13,8 @@ from docling.datamodel.pipeline_options import (
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
 from docling.utils.accelerator_utils import decide_device

+_log = logging.getLogger(__name__)
+

 class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
     @classmethod
@@ -121,3 +124,42 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             )

             yield generated_texts[0].strip()
+
+    def _annotate_with_context(
+        self, image_context_map: Iterable[Tuple[Image.Image, str]]
+    ) -> Iterable[str]:
+        from transformers import GenerationConfig
+
+        for image, context in image_context_map:
+            # Create input messages with context
+            context_prompt = f"{context}\n{self.options.prompt}"
+            _log.debug("Prompt: %s", context_prompt)
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": context_prompt},
+                    ],
+                },
+            ]
+
+            # Prepare inputs
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            # Generate outputs
+            generated_ids = self.model.generate(
+                **inputs,
+                generation_config=GenerationConfig(**self.options.generation_config),
+            )
+            generated_texts = self.processor.batch_decode(
+                generated_ids[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )
+
+            yield generated_texts[0].strip()
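The VLM path builds a chat-formatted prompt per image and decodes only the newly generated tokens. Decoding behavior is driven entirely by options.generation_config, which is unpacked into a transformers GenerationConfig per request. A hedged sketch of tuning it (the keys shown are standard GenerationConfig fields, assumed for illustration rather than taken from this diff):

from docling.datamodel.pipeline_options import PictureDescriptionVlmOptions

options = PictureDescriptionVlmOptions(
    repo_id="ibm-granite/granite-vision-3.2-2b",
    # Assumed illustrative values; any transformers GenerationConfig kwargs fit here.
    generation_config=dict(max_new_tokens=200, do_sample=False),
)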
@@ -283,3 +283,9 @@ branch = "main"
 parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
 parser_angular_minor_types = "feat"
 parser_angular_patch_types = "fix,perf"
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
+log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
tests/test_picture_description.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import logging
from pathlib import Path

import pytest
import requests
from docling_core.types.doc.document import PictureDescriptionData

from docling.datamodel.base_models import ConversionStatus, InputFormat
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionApiOptions,
    PictureDescriptionVlmOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

# Configure logging at the top of the file
logging.basicConfig(level=logging.DEBUG)
_log = logging.getLogger(__name__)

IMAGE_RESOLUTION_SCALE = 2.0
LOCAL_VISION_MODEL = "ibm-granite/granite-vision-3.2-2b"
# LOCAL_VISION_MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"
API_VISION_MODEL = "granite3.2-vision:2b"
REMOTE_CHAT_API_URL = "http://localhost:8321/v1/openai/v1/chat/completions"  # for the llama-stack OpenAI API interface
DOC_SOURCE = "https://www.allspringglobal.com/globalassets/assets/regulatory/summary-prospectus/emerging-markets-equity-summ.pdf"
PROMPT = (
    "Please describe the image using the text above as additional context. "
    "Additionally, if the image contains a chart (like a bar chart, pie chart, line chart, etc.), "
    "please try to extract a list of data points (percentages, numbers, etc.) that are depicted in the chart. "
    "Also, based on the type of information extracted, "
    "when applicable try to summarize it using bullet points or even a tabular representation using markdown if possible."
)


def is_api_available(url: str, timeout: int = 3) -> bool:
    try:
        requests.get(url, timeout=timeout)
        return True
    except (requests.ConnectionError, requests.Timeout) as e:
        _log.debug(f"API endpoint {url} is not reachable: {e!s}")
        return False


def process_document(pipeline_options: PdfPipelineOptions):
    # Initialize document converter
    doc_converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
        }
    )

    # Convert test document
    _log.info(f"Converting {DOC_SOURCE} with VLM API")
    conversion_result = doc_converter.convert(source=DOC_SOURCE)

    # Basic conversion checks
    assert conversion_result.status == ConversionStatus.SUCCESS
    doc = conversion_result.document
    assert doc is not None

    # Verify pictures were processed
    assert len(doc.pictures) > 0

    # Check each picture for descriptions
    for picture in doc.pictures:
        # Not every picture has annotations: some pictures are too small
        # (based on the threshold param, 5% of the page area by default)
        # and get ignored by the conversion pipeline
        if len(picture.annotations) > 0:
            # Get the descriptions
            descriptions = [
                ann
                for ann in picture.annotations
                if isinstance(ann, PictureDescriptionData)
            ]
            assert len(descriptions) > 0

            # Verify each description is non-empty
            for desc in descriptions:
                assert isinstance(desc.text, str)
                assert len(desc.text) > 0
                _log.info(
                    f"\nPicture ref: {picture.get_ref().cref}, page #{picture.prov[0].page_no}"
                )
                _log.info(f"\tGenerated description: {desc.text}")
        else:
            _log.info(
                f"Picture {picture.get_ref().cref} has no annotations (too small?)"
            )


@pytest.mark.skipif(
    not is_api_available(REMOTE_CHAT_API_URL),
    reason="Remote API endpoint is not accessible",
)
def test_picture_description_context_api_integration():
    """Test that the context-window functionality works correctly in the picture description pipeline using a VLM served via API"""
    # Setup pipeline options with context windows
    pipeline_options = PdfPipelineOptions(
        images_scale=IMAGE_RESOLUTION_SCALE,
        do_picture_description=True,
        generate_picture_images=True,
        enable_remote_services=True,
        picture_description_options=PictureDescriptionApiOptions(
            url=REMOTE_CHAT_API_URL,
            params=dict(model=API_VISION_MODEL),
            text_context_window_size_before_picture=2,  # Get 2 text items before
            text_context_window_size_after_picture=1,  # Get 1 text item after
            prompt=PROMPT,
            timeout=90,
        ),
    )

    process_document(pipeline_options)


def test_picture_description_context_vlm_integration():
    """Test that the context-window functionality works correctly in the picture description pipeline"""
    # Setup pipeline options with context windows
    pipeline_options = PdfPipelineOptions(
        images_scale=IMAGE_RESOLUTION_SCALE,
        generate_page_images=True,
        do_picture_description=True,
        generate_picture_images=True,
        picture_description_options=PictureDescriptionVlmOptions(
            repo_id=LOCAL_VISION_MODEL,
            text_context_window_size_before_picture=2,  # Get 2 text items before
            text_context_window_size_after_picture=1,  # Get 1 text item after
            prompt=PROMPT,
        ),
    )

    process_document(pipeline_options)


def test_picture_description_no_context_vlm_integration():
    """Test that the picture description works without context windows"""
    # Setup pipeline options without context windows
    pipeline_options = PdfPipelineOptions(
        images_scale=IMAGE_RESOLUTION_SCALE,
        do_picture_description=True,
        generate_picture_images=True,
        picture_description_options=PictureDescriptionVlmOptions(
            repo_id=LOCAL_VISION_MODEL,
            text_context_window_size_before_picture=0,  # No text context
            text_context_window_size_after_picture=0,  # No text context
            prompt=PROMPT,
        ),
    )

    process_document(pipeline_options)
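
Outside of pytest, the same context-window pipeline can be driven directly through the converter; a minimal sketch assembled from the test helpers above (the source path is a placeholder):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionVlmOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

pipeline_options = PdfPipelineOptions(
    do_picture_description=True,
    generate_picture_images=True,
    picture_description_options=PictureDescriptionVlmOptions(
        repo_id="ibm-granite/granite-vision-3.2-2b",
        text_context_window_size_before_picture=2,
        text_context_window_size_after_picture=1,
    ),
)
converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert(source="path/to/document.pdf")  # placeholder path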