vlm description using AutoModelForVision2Seq

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2025-02-06 10:55:31 +01:00
parent dae30a48aa
commit 11e27930c4
10 changed files with 238 additions and 103 deletions

docling/cli/main.py

@@ -219,6 +219,10 @@ def convert(
         bool,
         typer.Option(..., help="Enable the formula enrichment model in the pipeline."),
     ] = False,
+    enrich_picture_desc: Annotated[
+        bool,
+        typer.Option(..., help="Enable the picture description model in the pipeline."),
+    ] = False,
     artifacts_path: Annotated[
         Optional[Path],
         typer.Option(..., help="If provided, the location of the model artifacts."),
@@ -375,6 +379,7 @@ def convert(
         do_table_structure=True,
         do_code_enrichment=enrich_code,
         do_formula_enrichment=enrich_formula,
+        do_picture_description=enrich_picture_desc,
         document_timeout=document_timeout,
     )
     pipeline_options.table_structure_options.do_cell_matching = (
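
The new CLI flag maps one-to-one onto the pipeline option it feeds. A minimal programmatic sketch of the same switch (the snippet is illustrative, not part of the commit):

from docling.datamodel.pipeline_options import PdfPipelineOptions

pipeline_options = PdfPipelineOptions(
    do_picture_description=True,  # programmatic equivalent of the new CLI option
)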

docling/datamodel/pipeline_options.py

@ -197,7 +197,7 @@ class PicDescBaseOptions(BaseModel):
class PicDescApiOptions(PicDescBaseOptions): class PicDescApiOptions(PicDescBaseOptions):
kind: Literal["api"] = "api" kind: Literal["api"] = "api"
url: AnyUrl = AnyUrl("") url: AnyUrl = AnyUrl("http://localhost/")
headers: Dict[str, str] = {} headers: Dict[str, str] = {}
params: Dict[str, Any] = {} params: Dict[str, Any] = {}
timeout: float = 20 timeout: float = 20
@ -206,22 +206,29 @@ class PicDescApiOptions(PicDescBaseOptions):
provenance: str = "" provenance: str = ""
class PicDescVllmOptions(PicDescBaseOptions): class PicDescVlmOptions(PicDescBaseOptions):
kind: Literal["vllm"] = "vllm" kind: Literal["vlm"] = "vlm"
# For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html repo_id: str
prompt: str = "Describe this image in a few sentences."
max_new_tokens: int = 200
# Parameters for LLaVA-1.6/LLaVA-NeXT
llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
llm_extra: Dict[str, Any] = dict(max_model_len=8192)
# Parameters for Phi-3-Vision # class PicDescSmolVlmOptions(PicDescVlmOptions):
# llm_name: str = "microsoft/Phi-3-vision-128k-instruct" # repo_id: str = "HuggingFaceTB/SmolVLM-256M-Instruct"
# llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
# llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)
sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)
# class PicDescGraniteOptions(PicDescVlmOptions):
# repo_id: str = "ibm-granite/granite-vision-3.1-2b-preview"
# prompt: str = "What is shown in this image?"
smolvlm_pic_desc = PicDescVlmOptions(repo_id="HuggingFaceTB/SmolVLM-256M-Instruct")
# phi_pic_desc = PicDescVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
granite_pic_desc = PicDescVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
prompt="What is shown in this image?",
)
# Define an enum for the backend options # Define an enum for the backend options
@ -274,8 +281,8 @@ class PdfPipelineOptions(PipelineOptions):
RapidOcrOptions, RapidOcrOptions,
] = Field(EasyOcrOptions(), discriminator="kind") ] = Field(EasyOcrOptions(), discriminator="kind")
picture_description_options: Annotated[ picture_description_options: Annotated[
Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind") Union[PicDescApiOptions, PicDescVlmOptions], Field(discriminator="kind")
] = PicDescApiOptions() # TODO: needs defaults or optional ] = smolvlm_pic_desc
images_scale: float = 1.0 images_scale: float = 1.0
generate_page_images: bool = False generate_page_images: bool = False
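
The options model is a pydantic discriminated union keyed on `kind`, so the VLM and API variants can be swapped without touching the pipeline. A sketch of selecting a checkpoint explicitly (the repo_id is the commit's own SmolVLM default; assuming any Vision2Seq-compatible repo can be substituted):

from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescVlmOptions

opts = PdfPipelineOptions(do_picture_description=True)
opts.picture_description_options = PicDescVlmOptions(
    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",  # default from smolvlm_pic_desc
    prompt="Describe this image in a few sentences.",
    max_new_tokens=200,
)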

docling/models/base_model.py

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, Optional from typing import Any, Generic, Iterable, Optional
from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from typing_extensions import TypeVar from typing_extensions import TypeVar
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
@ -61,7 +61,7 @@ class BaseItemAndImageEnrichmentModel(
if not self.is_processable(doc=conv_res.document, element=element): if not self.is_processable(doc=conv_res.document, element=element):
return None return None
assert isinstance(element, TextItem) assert isinstance(element, DocItem)
element_prov = element.prov[0] element_prov = element.prov[0]
bbox = element_prov.bbox bbox = element_prov.bbox
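
The widened assert matters because PictureItem derives from DocItem but not from TextItem, so picture elements can now flow through the item-and-image enrichment path. A quick sanity check, assuming the docling_core class hierarchy:

from docling_core.types.doc import DocItem, PictureItem, TextItem

assert issubclass(PictureItem, DocItem)
assert not issubclass(PictureItem, TextItem)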

docling/models/pic_description_base_model.py

@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Iterable
+from typing import Any, Iterable, List, Optional, Union

 from docling_core.types.doc import (
     DoclingDocument,
@@ -11,36 +11,54 @@ from docling_core.types.doc import (
 from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
     PictureDescriptionData,
 )
+from PIL import Image

 from docling.datamodel.pipeline_options import PicDescBaseOptions
-from docling.models.base_model import BaseEnrichmentModel
+from docling.models.base_model import (
+    BaseItemAndImageEnrichmentModel,
+    ItemAndImageEnrichmentElement,
+)

-class PictureDescriptionBaseModel(BaseEnrichmentModel):
+class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
+    images_scale: float = 2.0

-    def __init__(self, enabled: bool, options: PicDescBaseOptions):
+    def __init__(
+        self,
+        enabled: bool,
+        options: PicDescBaseOptions,
+    ):
         self.enabled = enabled
         self.options = options
-        self.provenance = "TODO"
+        self.provenance = "not-implemented"

     def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        # TODO: once the image classifier is active, we can differentiate among image types
         return self.enabled and isinstance(element, PictureItem)

-    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         raise NotImplementedError

     def __call__(
-        self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
-    ) -> Iterable[Any]:
+        self,
+        doc: DoclingDocument,
+        element_batch: Iterable[ItemAndImageEnrichmentElement],
+    ) -> Iterable[NodeItem]:
         if not self.enabled:
+            for element in element_batch:
+                yield element.item
             return

-        for element in element_batch:
-            assert isinstance(element, PictureItem)
-            assert element.image is not None
+        images: List[Image.Image] = []
+        elements: List[PictureItem] = []
+        for el in element_batch:
+            assert isinstance(el.item, PictureItem)
+            elements.append(el.item)
+            images.append(el.image)

-            annotation = self._annotate_image(element)
-            element.annotations.append(annotation)
-            yield element
+        outputs = self._annotate_images(images)
+
+        for item, output in zip(elements, outputs):
+            item.annotations.append(
+                PictureDescriptionData(text=output, provenance=self.provenance)
+            )
+            yield item
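
Subclasses now implement only the batch-oriented `_annotate_images` hook; the base class collects the cropped images, calls the hook once, and zips the returned strings back onto the PictureItems as PictureDescriptionData. A minimal hypothetical subclass, just to show the contract (EchoDescriptionModel is not part of the commit):

from typing import Iterable

from PIL import Image

from docling.models.pic_description_base_model import PictureDescriptionBaseModel


class EchoDescriptionModel(PictureDescriptionBaseModel):
    # Illustrative stub: "describes" each image by its pixel size only.
    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
        for image in images:
            yield f"Image of size {image.width}x{image.height} px"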

docling/models/pic_description_vllm_model.py (deleted)

@ -1,59 +0,0 @@
import json
from typing import List
from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from docling.datamodel.pipeline_options import PicDescVllmOptions
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
class PictureDescriptionVllmModel(PictureDescriptionBaseModel):
def __init__(self, enabled: bool, options: PicDescVllmOptions):
super().__init__(enabled=enabled, options=options)
self.options: PicDescVllmOptions
if self.enabled:
raise NotImplementedError
if self.enabled:
try:
from vllm import LLM, SamplingParams # type: ignore
except ImportError:
raise ImportError(
"VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
)
self.sampling_params = SamplingParams(**self.options.sampling_params) # type: ignore
self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra) # type: ignore
# Generate a stable hash from the extra parameters
def create_hash(t):
return ""
params_hash = create_hash(
json.dumps(self.options.llm_extra, sort_keys=True)
+ json.dumps(self.options.sampling_params, sort_keys=True)
)
self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"
def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
assert picture.image is not None
from vllm import RequestOutput
inputs = [
{
"prompt": self.options.llm_prompt,
"multi_modal_data": {"image": picture.image.pil_image},
}
]
outputs: List[RequestOutput] = self.llm.generate( # type: ignore
inputs, sampling_params=self.sampling_params # type: ignore
)
generated_text = outputs[0].outputs[0].text
return PictureDescriptionData(provenance=self.provenance, text=generated_text)

docling/models/pic_description_vlm_model.py (new)

@@ -0,0 +1,104 @@
+import json
+from pathlib import Path
+from typing import Iterable, List, Optional, Union
+
+from PIL import Image
+
+from docling.datamodel.pipeline_options import AcceleratorOptions, PicDescVlmOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+from docling.utils.accelerator_utils import decide_device
+
+
+class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Union[Path, str]],
+        options: PicDescVlmOptions,
+        accelerator_options: AcceleratorOptions,
+    ):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescVlmOptions
+
+        if self.enabled:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(repo_id=self.options.repo_id)
+
+            self.device = decide_device(accelerator_options.device)
+
+            try:
+                import torch
+                from transformers import AutoModelForVision2Seq, AutoProcessor
+            except ImportError:
+                raise ImportError(
+                    "transformers >=4.46 is not installed. Please install Docling with the required extras `pip install docling[vlm]`."
+                )
+
+            # Initialize processor and model
+            self.processor = AutoProcessor.from_pretrained(self.options.repo_id)
+            self.model = AutoModelForVision2Seq.from_pretrained(
+                self.options.repo_id,
+                torch_dtype=torch.bfloat16,
+                _attn_implementation=(
+                    "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                ),
+            ).to(self.device)
+
+            self.provenance = f"{self.options.repo_id}"
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+        )
+
+        return Path(download_path)
+
+    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+        # Create input messages
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": self.options.prompt},
+                ],
+            },
+        ]
+
+        # TODO: set seed for reproducibility
+        # TODO: do batch generation
+
+        for image in images:
+            # Prepare inputs
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            # Generate outputs
+            generated_ids = self.model.generate(
+                **inputs, max_new_tokens=self.options.max_new_tokens
+            )
+            generated_texts = self.processor.batch_decode(
+                generated_ids[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )
+
+            yield generated_texts[0].strip()
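
Note the decode slice: `generated_ids[:, inputs["input_ids"].shape[1]:]` drops the echoed prompt tokens so only newly generated text reaches `batch_decode`. The same pattern, stripped of the Docling wrapper, as a standalone sketch (the checkpoint is the commit's SmolVLM default; the blank image is a placeholder input):

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

repo_id = "HuggingFaceTB/SmolVLM-256M-Instruct"
processor = AutoProcessor.from_pretrained(repo_id)
model = AutoModelForVision2Seq.from_pretrained(repo_id, torch_dtype=torch.bfloat16)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image in a few sentences."},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image = Image.new("RGB", (64, 64), "white")  # placeholder image
inputs = processor(text=prompt, images=[image], return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=200)
new_tokens = generated_ids[:, inputs["input_ids"].shape[1] :]  # strip the prompt echo
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip())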

docling/pipeline/standard_pdf_pipeline.py

@@ -14,7 +14,7 @@ from docling.datamodel.pipeline_options import (
     OcrMacOptions,
     PdfPipelineOptions,
     PicDescApiOptions,
-    PicDescVllmOptions,
+    PicDescVlmOptions,
     RapidOcrOptions,
     TesseractCliOcrOptions,
     TesseractOcrOptions,
@@ -36,7 +36,7 @@ from docling.models.page_preprocessing_model import (
 )
 from docling.models.pic_description_api_model import PictureDescriptionApiModel
 from docling.models.pic_description_base_model import PictureDescriptionBaseModel
-from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
+from docling.models.pic_description_vlm_model import PictureDescriptionVlmModel
 from docling.models.rapid_ocr_model import RapidOcrModel
 from docling.models.table_structure_model import TableStructureModel
 from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
@@ -132,6 +132,7 @@ class StandardPdfPipeline(PaginatedPipeline):
         if (
             self.pipeline_options.do_formula_enrichment
             or self.pipeline_options.do_code_enrichment
+            or self.pipeline_options.do_picture_description
         ):
             self.keep_backend = True
@@ -186,7 +187,9 @@ class StandardPdfPipeline(PaginatedPipeline):
         )
         return None

-    def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
+    def get_pic_description_model(
+        self, artifacts_path: Optional[Path] = None
+    ) -> Optional[PictureDescriptionBaseModel]:
         if isinstance(
             self.pipeline_options.picture_description_options, PicDescApiOptions
         ):
@@ -195,11 +198,13 @@ class StandardPdfPipeline(PaginatedPipeline):
                 options=self.pipeline_options.picture_description_options,
             )
         elif isinstance(
-            self.pipeline_options.picture_description_options, PicDescVllmOptions
+            self.pipeline_options.picture_description_options, PicDescVlmOptions
         ):
-            return PictureDescriptionVllmModel(
+            return PictureDescriptionVlmModel(
                 enabled=self.pipeline_options.do_picture_description,
+                artifacts_path=artifacts_path,
                 options=self.pipeline_options.picture_description_options,
+                accelerator_options=self.pipeline_options.accelerator_options,
             )
         return None

New example script

@@ -0,0 +1,48 @@
+import logging
+from pathlib import Path
+
+from docling_core.types.doc import PictureItem
+
+from docling.datamodel.base_models import InputFormat
+from docling.datamodel.pipeline_options import (  # PicDescSmolVlmOptions, PicDescGraniteOptions
+    PdfPipelineOptions,
+    granite_pic_desc,
+    smolvlm_pic_desc,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    input_doc_path = Path("./tests/data/2206.01062.pdf")
+
+    pipeline_options = PdfPipelineOptions()
+    pipeline_options.do_picture_description = True
+    pipeline_options.picture_description_options = smolvlm_pic_desc
+    # pipeline_options.picture_description_options = granite_pic_desc
+
+    pipeline_options.picture_description_options.prompt = (
+        "Describe the image in three sentences. Be concise and accurate."
+    )
+
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_options=pipeline_options,
+            )
+        }
+    )
+
+    result = doc_converter.convert(input_doc_path)
+
+    for element, _level in result.document.iterate_items():
+        if isinstance(element, PictureItem):
+            print(
+                f"Picture {element.self_ref}\n"
+                f"Caption: {element.caption_text(doc=result.document)}\n"
+                f"Annotations: {element.annotations}"
+            )
+
+
+if __name__ == "__main__":
+    main()
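
The API-backed path remains available alongside the local VLM. A hedged sketch of pointing PicDescApiOptions at an OpenAI-style endpoint (the URL and the model name in params are placeholders, not a documented service):

from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescApiOptions

pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True
pipeline_options.picture_description_options = PicDescApiOptions(
    url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
    params={"model": "my-vlm"},  # hypothetical model identifier
    timeout=20,
)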

poetry.lock (generated)

@@ -3823,10 +3823,10 @@ files = [
 numpy = [
     {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
     {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
     {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
 ]

 [[package]]
@@ -3849,10 +3849,10 @@ files = [
 numpy = [
     {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
     {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
     {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
 ]

 [[package]]
@@ -4037,8 +4037,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.22.4", markers = "python_version < \"3.11\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
     {version = ">=1.23.2", markers = "python_version == \"3.11\""},
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -7747,8 +7747,9 @@ type = ["pytest-mypy"]
 ocrmac = ["ocrmac"]
 rapidocr = ["onnxruntime", "onnxruntime", "rapidocr-onnxruntime"]
 tesserocr = ["tesserocr"]
+vlm = ["transformers", "transformers"]

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "08d30cee8d77f9beee32d5dbec1643367ecae2b4c4b47b57fcb337711471eb5c"
+content-hash = "c1c121c7b5650bf37611765224d9628ad814d440d1e7e9d5c959d97a8e16f94c"

pyproject.toml

@@ -56,6 +56,10 @@ onnxruntime = [
     { version = ">=1.7.0,<1.20.0", optional = true, markers = "python_version < '3.10'" },
     { version = "^1.7.0", optional = true, markers = "python_version >= '3.10'" }
 ]
+transformers = [
+    {markers = "sys_platform != 'darwin' or platform_machine != 'x86_64'", version = "^4.46.0", optional = true },
+    {markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", version = "~4.42.0", optional = true }
+]
 pillow = "^10.0.0"

 [tool.poetry.group.dev.dependencies]
@@ -116,6 +120,7 @@ torchvision = [
 [tool.poetry.extras]
 tesserocr = ["tesserocr"]
 ocrmac = ["ocrmac"]
+vlm = ["transformers"]
 rapidocr = ["rapidocr-onnxruntime", "onnxruntime"]

 [tool.poetry.scripts]
@@ -156,7 +161,8 @@ module = [
     "deepsearch_glm.*",
     "lxml.*",
     "bs4.*",
-    "huggingface_hub.*"
+    "huggingface_hub.*",
+    "transformers.*",
 ]
 ignore_missing_imports = true
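
With the lockfile and extras updated, the transformers dependency stays opt-in via the new `vlm` extra; as the ImportError message in the model code states, it is installed with `pip install docling[vlm]`.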