Mirror of https://github.com/DS4SD/docling.git, synced 2025-07-24 19:14:23 +00:00
feat: add image-text-to-text models in transformers (#1772)
* feat(dolphin): add dolphin support
* rename
* reformat
* fix mypy
* add prompt style and examples

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent e25873d557
commit a07ba863c4
docling/datamodel/pipeline_options_vlm_model.py

```diff
@@ -31,6 +31,12 @@ class TransformersModelType(str, Enum):
     AUTOMODEL = "automodel"
     AUTOMODEL_VISION2SEQ = "automodel-vision2seq"
     AUTOMODEL_CAUSALLM = "automodel-causallm"
+    AUTOMODEL_IMAGETEXTTOTEXT = "automodel-imagetexttotext"
+
+
+class TransformersPromptStyle(str, Enum):
+    CHAT = "chat"
+    RAW = "raw"
 
 
 class InlineVlmOptions(BaseVlmOptions):
```
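Both enums subclass `str`, so their members serialize as the plain strings shown above when the options are dumped or reloaded. A quick round-trip sketch (import path taken from the diff above):

```python
from docling.datamodel.pipeline_options_vlm_model import (
    TransformersModelType,
    TransformersPromptStyle,
)

# str-backed enum members compare equal to their string values ...
assert TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT == "automodel-imagetexttotext"
# ... and can be reconstructed from them, e.g. when options are deserialized.
assert TransformersPromptStyle("raw") is TransformersPromptStyle.RAW
```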
```diff
@@ -44,6 +50,7 @@ class InlineVlmOptions(BaseVlmOptions):
 
     inference_framework: InferenceFramework
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
+    transformers_prompt_style: TransformersPromptStyle = TransformersPromptStyle.CHAT
     response_format: ResponseFormat
 
     torch_dtype: Optional[str] = None
```
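Since `transformers_prompt_style` defaults to `CHAT`, existing configurations keep their behavior; only options that opt into `RAW` bypass the chat template. A minimal construction sketch, mirroring the `dolphin_oneshot` example added later in this commit (required fields inferred from that example):

```python
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
    TransformersModelType,
    TransformersPromptStyle,
)

opts = InlineVlmOptions(
    repo_id="ByteDance/Dolphin",
    prompt="<s>Read text in the image. <Answer/>",  # already fully formatted
    response_format=ResponseFormat.MARKDOWN,
    inference_framework=InferenceFramework.TRANSFORMERS,
    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
    transformers_prompt_style=TransformersPromptStyle.RAW,  # skip the chat template
)
```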
In the transformers inline VLM model (class `HuggingFaceTransformersVlmModel`):

```diff
@@ -13,6 +13,7 @@ from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
     TransformersModelType,
+    TransformersPromptStyle,
 )
 from docling.models.base_model import BasePageModel
 from docling.models.utils.hf_model_download import (
```
```diff
@@ -41,6 +42,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
         from transformers import (
             AutoModel,
             AutoModelForCausalLM,
+            AutoModelForImageTextToText,
             AutoModelForVision2Seq,
             AutoProcessor,
             BitsAndBytesConfig,
```
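`AutoModelForImageTextToText` only exists in newer `transformers` releases (the exact minimum version is not stated in this commit), so a guarded import along these lines can give a clearer error on older installs (sketch):

```python
try:
    from transformers import AutoModelForImageTextToText
except ImportError as err:  # older transformers without the image-text-to-text auto class
    raise ImportError(
        "AutoModelForImageTextToText is unavailable; please upgrade transformers."
    ) from err
```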
```diff
@@ -91,6 +93,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 == TransformersModelType.AUTOMODEL_VISION2SEQ
             ):
                 model_cls = AutoModelForVision2Seq
+            elif (
+                self.vlm_options.transformers_model_type
+                == TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT
+            ):
+                model_cls = AutoModelForImageTextToText
 
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
```
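Outside of docling, the new branch amounts to the following standalone loading sketch (repo id borrowed from the example added below; dtype, device placement, and quantization handling omitted for brevity):

```python
from transformers import AutoModelForImageTextToText, AutoProcessor

repo_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = AutoProcessor.from_pretrained(repo_id)
model = AutoModelForImageTextToText.from_pretrained(repo_id)
```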
```diff
@@ -169,7 +176,10 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
     def formulate_prompt(self, user_prompt: str) -> str:
         """Formulate a prompt for the VLM."""
 
-        if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
+            return user_prompt
+
+        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             _log.debug("Using specialized prompt for Phi-4")
             # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
 
```
```diff
@@ -182,20 +192,25 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
 
             return prompt
 
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "This is a page from a document.",
-                    },
-                    {"type": "image"},
-                    {"type": "text", "text": user_prompt},
-                ],
-            }
-        ]
-        prompt = self.processor.apply_chat_template(
-            messages, add_generation_prompt=False
-        )
-        return prompt
+        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "This is a page from a document.",
+                        },
+                        {"type": "image"},
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+            ]
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=False
+            )
+            return prompt
+
+        raise RuntimeError(
+            f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+        )
```
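To see what the `CHAT` branch renders, the same message structure can be applied directly to a processor; the exact output depends on each model's chat template (sketch, model id borrowed from the example below):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "This is a page from a document."},
            {"type": "image"},
            {"type": "text", "text": "Convert this page to markdown."},
        ],
    }
]
# Renders the model-specific prompt string, including an image placeholder token.
print(processor.apply_chat_template(messages, add_generation_prompt=False))
```

With `RAW`, none of this applies: the configured prompt string (e.g. Dolphin's `<s>Read text in the image. <Answer/>`) reaches the model verbatim.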
docs/examples/compare_vlm_models.py (vendored, 39 lines changed)
```diff
@@ -14,11 +14,18 @@ from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
 from tabulate import tabulate
 
 from docling.datamodel import vlm_model_specs
+from docling.datamodel.accelerator_options import AcceleratorDevice
 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_options import (
     VlmPipelineOptions,
 )
-from docling.datamodel.pipeline_options_vlm_model import InferenceFramework
+from docling.datamodel.pipeline_options_vlm_model import (
+    InferenceFramework,
+    InlineVlmOptions,
+    ResponseFormat,
+    TransformersModelType,
+    TransformersPromptStyle,
+)
 from docling.document_converter import DocumentConverter, PdfFormatOption
 from docling.pipeline.vlm_pipeline import VlmPipeline
 
```
```diff
@@ -101,6 +108,33 @@ if __name__ == "__main__":
     out_path = Path("scratch")
     out_path.mkdir(parents=True, exist_ok=True)
 
+    ## Definition of more inline models
+    llava_qwen = InlineVlmOptions(
+        repo_id="llava-hf/llava-interleave-qwen-0.5b-hf",
+        # prompt="Read text in the image.",
+        prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
+        # prompt="Parse the reading order of this document.",
+        response_format=ResponseFormat.MARKDOWN,
+        inference_framework=InferenceFramework.TRANSFORMERS,
+        transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+        supported_devices=[AcceleratorDevice.CUDA, AcceleratorDevice.CPU],
+        scale=2.0,
+        temperature=0.0,
+    )
+
+    # Note that this is not the expected way of using the Dolphin model, but it shows the usage of a raw prompt.
+    dolphin_oneshot = InlineVlmOptions(
+        repo_id="ByteDance/Dolphin",
+        prompt="<s>Read text in the image. <Answer/>",
+        response_format=ResponseFormat.MARKDOWN,
+        inference_framework=InferenceFramework.TRANSFORMERS,
+        transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+        transformers_prompt_style=TransformersPromptStyle.RAW,
+        supported_devices=[AcceleratorDevice.CUDA, AcceleratorDevice.CPU],
+        scale=2.0,
+        temperature=0.0,
+    )
+
     ## Use VlmPipeline
     pipeline_options = VlmPipelineOptions()
     pipeline_options.generate_page_images = True
```
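For reference, a hedged sketch of wiring one of these option sets into a converter on its own, following the pattern this example file uses further down (input path is a placeholder):

```python
pipeline_options = VlmPipelineOptions(vlm_options=dolphin_oneshot)
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("page.pdf")  # placeholder input document
print(result.document.export_to_markdown())
```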
```diff
@@ -121,6 +155,9 @@ if __name__ == "__main__":
         vlm_model_specs.GRANITE_VISION_TRANSFORMERS,
         vlm_model_specs.PHI4_TRANSFORMERS,
         vlm_model_specs.PIXTRAL_12B_TRANSFORMERS,
+        ## More inline models
+        dolphin_oneshot,
+        llava_qwen,
     ]
 
     # Remove MLX models if not on Mac
```