From 11e27930c408c2a01b3ce62d8b353e4fa7fcbdb4 Mon Sep 17 00:00:00 2001 From: Michele Dolfi Date: Thu, 6 Feb 2025 10:55:31 +0100 Subject: [PATCH] vlm description using AutoModelForVision2Seq Signed-off-by: Michele Dolfi --- docling/cli/main.py | 5 + docling/datamodel/pipeline_options.py | 37 ++++--- docling/models/base_model.py | 4 +- docling/models/pic_description_base_model.py | 48 ++++++--- docling/models/pic_description_vllm_model.py | 59 ----------- docling/models/pic_description_vlm_model.py | 104 +++++++++++++++++++ docling/pipeline/standard_pdf_pipeline.py | 15 ++- docs/examples/pictures_description.py | 48 +++++++++ poetry.lock | 13 +-- pyproject.toml | 8 +- 10 files changed, 238 insertions(+), 103 deletions(-) delete mode 100644 docling/models/pic_description_vllm_model.py create mode 100644 docling/models/pic_description_vlm_model.py create mode 100644 docs/examples/pictures_description.py diff --git a/docling/cli/main.py b/docling/cli/main.py index 7d31221d..a369c1b6 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -219,6 +219,10 @@ def convert( bool, typer.Option(..., help="Enable the formula enrichment model in the pipeline."), ] = False, + enrich_picture_desc: Annotated[ + bool, + typer.Option(..., help="Enable the picture description model in the pipeline."), + ] = False, artifacts_path: Annotated[ Optional[Path], typer.Option(..., help="If provided, the location of the model artifacts."), @@ -375,6 +379,7 @@ def convert( do_table_structure=True, do_code_enrichment=enrich_code, do_formula_enrichment=enrich_formula, + do_picture_description=enrich_picture_desc, document_timeout=document_timeout, ) pipeline_options.table_structure_options.do_cell_matching = ( diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 80eb26b9..5f35b520 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -197,7 +197,7 @@ class PicDescBaseOptions(BaseModel): class PicDescApiOptions(PicDescBaseOptions): kind: Literal["api"] = "api" - url: AnyUrl = AnyUrl("") + url: AnyUrl = AnyUrl("http://localhost/") headers: Dict[str, str] = {} params: Dict[str, Any] = {} timeout: float = 20 @@ -206,22 +206,29 @@ class PicDescApiOptions(PicDescBaseOptions): provenance: str = "" -class PicDescVllmOptions(PicDescBaseOptions): - kind: Literal["vllm"] = "vllm" +class PicDescVlmOptions(PicDescBaseOptions): + kind: Literal["vlm"] = "vlm" - # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html + repo_id: str + prompt: str = "Describe this image in a few sentences." + max_new_tokens: int = 200 - # Parameters for LLaVA-1.6/LLaVA-NeXT - llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf" - llm_prompt: str = "[INST] \nDescribe the image in details. [/INST]" - llm_extra: Dict[str, Any] = dict(max_model_len=8192) - # Parameters for Phi-3-Vision - # llm_name: str = "microsoft/Phi-3-vision-128k-instruct" - # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n" - # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True) +# class PicDescSmolVlmOptions(PicDescVlmOptions): +# repo_id: str = "HuggingFaceTB/SmolVLM-256M-Instruct" - sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42) + +# class PicDescGraniteOptions(PicDescVlmOptions): +# repo_id: str = "ibm-granite/granite-vision-3.1-2b-preview" +# prompt: str = "What is shown in this image?" 
+ + +smolvlm_pic_desc = PicDescVlmOptions(repo_id="HuggingFaceTB/SmolVLM-256M-Instruct") +# phi_pic_desc = PicDescVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct") +granite_pic_desc = PicDescVlmOptions( + repo_id="ibm-granite/granite-vision-3.1-2b-preview", + prompt="What is shown in this image?", +) # Define an enum for the backend options @@ -274,8 +281,8 @@ class PdfPipelineOptions(PipelineOptions): RapidOcrOptions, ] = Field(EasyOcrOptions(), discriminator="kind") picture_description_options: Annotated[ - Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind") - ] = PicDescApiOptions() # TODO: needs defaults or optional + Union[PicDescApiOptions, PicDescVlmOptions], Field(discriminator="kind") + ] = smolvlm_pic_desc images_scale: float = 1.0 generate_page_images: bool = False diff --git a/docling/models/base_model.py b/docling/models/base_model.py index 08d728cc..7fe5c0a7 100644 --- a/docling/models/base_model.py +++ b/docling/models/base_model.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Generic, Iterable, Optional -from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem +from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem from typing_extensions import TypeVar from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page @@ -61,7 +61,7 @@ class BaseItemAndImageEnrichmentModel( if not self.is_processable(doc=conv_res.document, element=element): return None - assert isinstance(element, TextItem) + assert isinstance(element, DocItem) element_prov = element.prov[0] bbox = element_prov.bbox diff --git a/docling/models/pic_description_base_model.py b/docling/models/pic_description_base_model.py index 673ffd90..9be9e678 100644 --- a/docling/models/pic_description_base_model.py +++ b/docling/models/pic_description_base_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Iterable +from typing import Any, Iterable, List, Optional, Union from docling_core.types.doc import ( DoclingDocument, @@ -11,36 +11,54 @@ from docling_core.types.doc import ( from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc PictureDescriptionData, ) +from PIL import Image from docling.datamodel.pipeline_options import PicDescBaseOptions -from docling.models.base_model import BaseEnrichmentModel +from docling.models.base_model import ( + BaseItemAndImageEnrichmentModel, + ItemAndImageEnrichmentElement, +) -class PictureDescriptionBaseModel(BaseEnrichmentModel): +class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): + images_scale: float = 2.0 - def __init__(self, enabled: bool, options: PicDescBaseOptions): + def __init__( + self, + enabled: bool, + options: PicDescBaseOptions, + ): self.enabled = enabled self.options = options - self.provenance = "TODO" + self.provenance = "not-implemented" def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool: - # TODO: once the image classifier is active, we can differentiate among image types return self.enabled and isinstance(element, PictureItem) - def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData: + def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]: raise NotImplementedError def __call__( - self, doc: DoclingDocument, element_batch: Iterable[NodeItem] - ) -> Iterable[Any]: + self, + doc: DoclingDocument, + element_batch: Iterable[ItemAndImageEnrichmentElement], + ) -> Iterable[NodeItem]: if 
not self.enabled: + for element in element_batch: + yield element.item return - for element in element_batch: - assert isinstance(element, PictureItem) - assert element.image is not None + images: List[Image.Image] = [] + elements: List[PictureItem] = [] + for el in element_batch: + assert isinstance(el.item, PictureItem) + elements.append(el.item) + images.append(el.image) - annotation = self._annotate_image(element) - element.annotations.append(annotation) + outputs = self._annotate_images(images) - yield element + for item, output in zip(elements, outputs): + item.annotations.append( + PictureDescriptionData(text=output, provenance=self.provenance) + ) + yield item diff --git a/docling/models/pic_description_vllm_model.py b/docling/models/pic_description_vllm_model.py deleted file mode 100644 index 84a28104..00000000 --- a/docling/models/pic_description_vllm_model.py +++ /dev/null @@ -1,59 +0,0 @@ -import json -from typing import List - -from docling_core.types.doc import PictureItem -from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc - PictureDescriptionData, -) - -from docling.datamodel.pipeline_options import PicDescVllmOptions -from docling.models.pic_description_base_model import PictureDescriptionBaseModel - - -class PictureDescriptionVllmModel(PictureDescriptionBaseModel): - - def __init__(self, enabled: bool, options: PicDescVllmOptions): - super().__init__(enabled=enabled, options=options) - self.options: PicDescVllmOptions - - if self.enabled: - raise NotImplementedError - - if self.enabled: - try: - from vllm import LLM, SamplingParams # type: ignore - except ImportError: - raise ImportError( - "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`." - ) - - self.sampling_params = SamplingParams(**self.options.sampling_params) # type: ignore - self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra) # type: ignore - - # Generate a stable hash from the extra parameters - def create_hash(t): - return "" - - params_hash = create_hash( - json.dumps(self.options.llm_extra, sort_keys=True) - + json.dumps(self.options.sampling_params, sort_keys=True) - ) - self.provenance = f"{self.options.llm_name}-{params_hash[:8]}" - - def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData: - assert picture.image is not None - - from vllm import RequestOutput - - inputs = [ - { - "prompt": self.options.llm_prompt, - "multi_modal_data": {"image": picture.image.pil_image}, - } - ] - outputs: List[RequestOutput] = self.llm.generate( # type: ignore - inputs, sampling_params=self.sampling_params # type: ignore - ) - - generated_text = outputs[0].outputs[0].text - return PictureDescriptionData(provenance=self.provenance, text=generated_text) diff --git a/docling/models/pic_description_vlm_model.py b/docling/models/pic_description_vlm_model.py new file mode 100644 index 00000000..3103c405 --- /dev/null +++ b/docling/models/pic_description_vlm_model.py @@ -0,0 +1,104 @@ +import json +from pathlib import Path +from typing import Iterable, List, Optional, Union + +from PIL import Image + +from docling.datamodel.pipeline_options import AcceleratorOptions, PicDescVlmOptions +from docling.models.pic_description_base_model import PictureDescriptionBaseModel +from docling.utils.accelerator_utils import decide_device + + +class PictureDescriptionVlmModel(PictureDescriptionBaseModel): + + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Union[Path, str]], + options: 
PicDescVlmOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__(enabled=enabled, options=options) + self.options: PicDescVlmOptions + + if self.enabled: + + if artifacts_path is None: + artifacts_path = self.download_models(repo_id=self.options.repo_id) + + self.device = decide_device(accelerator_options.device) + + try: + import torch + from transformers import AutoModelForVision2Seq, AutoProcessor + except ImportError: + raise ImportError( + "transformers >=4.46 is not installed. Please install Docling with the required extras `pip install docling[vlm]`." + ) + + # Initialize processor and model + self.processor = AutoProcessor.from_pretrained(self.options.repo_id) + self.model = AutoModelForVision2Seq.from_pretrained( + self.options.repo_id, + torch_dtype=torch.bfloat16, + _attn_implementation=( + "flash_attention_2" if self.device.startswith("cuda") else "eager" + ), + ).to(self.device) + + self.provenance = f"{self.options.repo_id}" + + @staticmethod + def download_models( + repo_id: str, + local_dir: Optional[Path] = None, + force: bool = False, + progress: bool = False, + ) -> Path: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import disable_progress_bars + + if not progress: + disable_progress_bars() + download_path = snapshot_download( + repo_id=repo_id, + force_download=force, + local_dir=local_dir, + ) + + return Path(download_path) + + def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]: + + # Create input messages + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": self.options.prompt}, + ], + }, + ] + + # TODO: set seed for reproducibility + # TODO: do batch generation + + for image in images: + # Prepare inputs + prompt = self.processor.apply_chat_template( + messages, add_generation_prompt=True + ) + inputs = self.processor(text=prompt, images=[image], return_tensors="pt") + inputs = inputs.to(self.device) + + # Generate outputs + generated_ids = self.model.generate( + **inputs, max_new_tokens=self.options.max_new_tokens + ) + generated_texts = self.processor.batch_decode( + generated_ids[:, inputs["input_ids"].shape[1] :], + skip_special_tokens=True, + ) + + yield generated_texts[0].strip() diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 1608f43d..c0faaf0e 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -14,7 +14,7 @@ from docling.datamodel.pipeline_options import ( OcrMacOptions, PdfPipelineOptions, PicDescApiOptions, - PicDescVllmOptions, + PicDescVlmOptions, RapidOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions, @@ -36,7 +36,7 @@ from docling.models.page_preprocessing_model import ( ) from docling.models.pic_description_api_model import PictureDescriptionApiModel from docling.models.pic_description_base_model import PictureDescriptionBaseModel -from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel +from docling.models.pic_description_vlm_model import PictureDescriptionVlmModel from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.table_structure_model import TableStructureModel from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel @@ -132,6 +132,7 @@ class StandardPdfPipeline(PaginatedPipeline): if ( self.pipeline_options.do_formula_enrichment or self.pipeline_options.do_code_enrichment + or self.pipeline_options.do_picture_description ): self.keep_backend = 
True @@ -186,7 +187,9 @@ class StandardPdfPipeline(PaginatedPipeline): ) return None - def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]: + def get_pic_description_model( + self, artifacts_path: Optional[Path] = None + ) -> Optional[PictureDescriptionBaseModel]: if isinstance( self.pipeline_options.picture_description_options, PicDescApiOptions ): @@ -195,11 +198,13 @@ options=self.pipeline_options.picture_description_options, ) elif isinstance( - self.pipeline_options.picture_description_options, PicDescVllmOptions + self.pipeline_options.picture_description_options, PicDescVlmOptions ): - return PictureDescriptionVllmModel( + return PictureDescriptionVlmModel( enabled=self.pipeline_options.do_picture_description, + artifacts_path=artifacts_path, options=self.pipeline_options.picture_description_options, + accelerator_options=self.pipeline_options.accelerator_options, ) return None diff --git a/docs/examples/pictures_description.py b/docs/examples/pictures_description.py new file mode 100644 index 00000000..d5f72045 --- /dev/null +++ b/docs/examples/pictures_description.py @@ -0,0 +1,48 @@ +import logging +from pathlib import Path + +from docling_core.types.doc import PictureItem + +from docling.datamodel.base_models import InputFormat +from docling.datamodel.pipeline_options import ( # PicDescSmolVlmOptions, PicDescGraniteOptions + PdfPipelineOptions, + granite_pic_desc, + smolvlm_pic_desc, +) +from docling.document_converter import DocumentConverter, PdfFormatOption + + +def main(): + logging.basicConfig(level=logging.INFO) + + input_doc_path = Path("./tests/data/2206.01062.pdf") + + pipeline_options = PdfPipelineOptions() + pipeline_options.do_picture_description = True + pipeline_options.picture_description_options = smolvlm_pic_desc + # pipeline_options.picture_description_options = granite_pic_desc + + pipeline_options.picture_description_options.prompt = ( + "Describe the image in three sentences. Be concise and accurate." 
+ ) + + doc_converter = DocumentConverter( + format_options={ + InputFormat.PDF: PdfFormatOption( + pipeline_options=pipeline_options, + ) + } + ) + result = doc_converter.convert(input_doc_path) + + for element, _level in result.document.iterate_items(): + if isinstance(element, PictureItem): + print( + f"Picture {element.self_ref}\n" + f"Caption: {element.caption_text(doc=result.document)}\n" + f"Annotations: {element.annotations}" + ) + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index 0d685fe3..60b085c1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3823,10 +3823,10 @@ files = [ numpy = [ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -3849,10 +3849,10 @@ files = [ numpy = [ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4037,8 +4037,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -7747,8 +7747,9 @@ type = ["pytest-mypy"] ocrmac = ["ocrmac"] rapidocr = ["onnxruntime", "onnxruntime", "rapidocr-onnxruntime"] tesserocr = ["tesserocr"] +vlm = ["transformers", "transformers"] [metadata] lock-version = "2.0" 
python-versions = "^3.9" -content-hash = "08d30cee8d77f9beee32d5dbec1643367ecae2b4c4b47b57fcb337711471eb5c" +content-hash = "c1c121c7b5650bf37611765224d9628ad814d440d1e7e9d5c959d97a8e16f94c" diff --git a/pyproject.toml b/pyproject.toml index d5ea2955..235e8015 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,10 @@ onnxruntime = [ { version = ">=1.7.0,<1.20.0", optional = true, markers = "python_version < '3.10'" }, { version = "^1.7.0", optional = true, markers = "python_version >= '3.10'" } ] +transformers = [ + {markers = "sys_platform != 'darwin' or platform_machine != 'x86_64'", version = "^4.46.0", optional = true }, + {markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", version = "~4.42.0", optional = true } +] pillow = "^10.0.0" [tool.poetry.group.dev.dependencies] @@ -116,6 +120,7 @@ torchvision = [ [tool.poetry.extras] tesserocr = ["tesserocr"] ocrmac = ["ocrmac"] +vlm = ["transformers"] rapidocr = ["rapidocr-onnxruntime", "onnxruntime"] [tool.poetry.scripts] @@ -156,7 +161,8 @@ module = [ "deepsearch_glm.*", "lxml.*", "bs4.*", - "huggingface_hub.*" + "huggingface_hub.*", + "transformers.*", ] ignore_missing_imports = true
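
For reference, a minimal sketch of wiring the new options class directly, complementing the bundled docs/examples/pictures_description.py (which uses the presets). Assumptions: docling is installed with the `vlm` extra added above, and `PicDescVlmOptions` keeps the `repo_id`, `prompt`, and `max_new_tokens` fields introduced in this patch; the repo id below is the value from the `smolvlm_pic_desc` preset, and the input path is illustrative.

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescVlmOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption

    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_picture_description = True
    # Any AutoModelForVision2Seq-compatible checkpoint should work here;
    # this repo id matches the smolvlm_pic_desc preset defined in this patch.
    pipeline_options.picture_description_options = PicDescVlmOptions(
        repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
        prompt="Describe this image in a few sentences.",
        max_new_tokens=200,
    )

    converter = DocumentConverter(
        format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
    )
    result = converter.convert("input.pdf")  # illustrative path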