fixed the MyPy

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-05-14 17:51:43 +02:00
parent a3716b1961
commit 7c67d2b2fe
10 changed files with 158 additions and 245 deletions

View File

@ -29,6 +29,13 @@ from docling.datamodel.base_models import (
OutputFormat,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_model_specializations import (
VlmModelType,
granite_vision_vlm_conversion_options,
granite_vision_vlm_ollama_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
@ -39,12 +46,7 @@ from docling.datamodel.pipeline_options import (
PdfPipeline,
PdfPipelineOptions,
TableFormerMode,
VlmModelType,
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
granite_vision_vlm_ollama_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption

View File

@ -1,10 +1,6 @@
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from docling_core.types.io import (
DocumentStream,
)
from docling_core.types.doc import (
BoundingBox,
DocItemLabel,
@ -14,6 +10,9 @@ from docling_core.types.doc import (
TableCell,
)
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
from docling_core.types.io import (
DocumentStream,
)
# DO NOT REMOVE; explicitly exposed from this location
from PIL.Image import Image
@ -148,11 +147,13 @@ class BasePageElement(BaseModel):
class LayoutPrediction(BaseModel):
clusters: List[Cluster] = []
class VlmPredictionToken(BaseModel):
text: str = ""
token: int = -1
logprob: float = -1
class VlmPrediction(BaseModel):
text: str = ""
generated_tokens: list[VlmPredictionToken] = []

View File

@ -16,6 +16,12 @@ from pydantic import (
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import deprecated
from docling.datamodel.pipeline_model_specializations import (
ApiVlmOptions,
HuggingFaceVlmOptions,
smoldocling_vlm_conversion_options,
)
_log = logging.getLogger(__name__)
@ -121,24 +127,22 @@ class RapidOcrOptions(OcrOptions):
lang: List[str] = [
"english",
"chinese",
] # However, language as a parameter is not supported by rapidocr yet and hence changing this options doesn't affect anything.
# For more details on supported languages by RapidOCR visit https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
]
# However, language as a parameter is not supported by rapidocr yet
# and hence changing this options doesn't affect anything.
# For more details on supported languages by RapidOCR visit
# https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
# For more details on the following options visit
# https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/
# For more details on the following options visit https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/
text_score: float = 0.5 # same default as rapidocr
use_det: Optional[bool] = None # same default as rapidocr
use_cls: Optional[bool] = None # same default as rapidocr
use_rec: Optional[bool] = None # same default as rapidocr
# class Device(Enum):
# CPU = "CPU"
# CUDA = "CUDA"
# DIRECTML = "DIRECTML"
# AUTO = "AUTO"
# device: Device = Device.AUTO # Default value is AUTO
print_verbose: bool = False # same default as rapidocr
det_model_path: Optional[str] = None # same default as rapidocr
@ -243,110 +247,18 @@ class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
return self.repo_id.replace("/", "--")
# SmolVLM
smolvlm_picture_description = PictureDescriptionVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
)
# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
# GraniteVision
granite_picture_description = PictureDescriptionVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
prompt="What is shown in this image?",
)
class BaseVlmOptions(BaseModel):
kind: str
prompt: str
class ResponseFormat(str, Enum):
DOCTAGS = "doctags"
MARKDOWN = "markdown"
HTML = "html"
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
OPENAI = "openai"
TRANSFORMERS_AutoModelForVision2Seq = "transformers-AutoModelForVision2Seq"
TRANSFORMERS_AutoModelForCausalLM = "transformers-AutoModelForCausalLM"
TRANSFORMERS_LlavaForConditionalGeneration = (
"transformers-LlavaForConditionalGeneration"
)
class HuggingFaceVlmOptions(BaseVlmOptions):
kind: Literal["hf_model_options"] = "hf_model_options"
repo_id: str
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
response_format: ResponseFormat
scale: float = 2.0
use_kv_cache: bool = True
max_new_tokens: int = 4096
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
class ApiVlmOptions(BaseVlmOptions):
kind: Literal["api_model_options"] = "api_model_options"
url: AnyUrl = AnyUrl(
"http://localhost:11434/v1/chat/completions"
) # Default to ollama
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
scale: float = 2.0
timeout: float = 60
response_format: ResponseFormat
smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
)
smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
)
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
prompt="OCR the full page to markdown.",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
)
granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
params={"model": "granite3.2-vision:2b"},
prompt="OCR the full page to markdown.",
scale=1.0,
timeout=120,
response_format=ResponseFormat.MARKDOWN,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
GRANITE_VISION = "granite_vision"
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
# Define an enum for the backend options
class PdfBackend(str, Enum):
"""Enum of valid PDF backends."""

View File

@ -94,7 +94,7 @@ class HuggingFaceMlxModel(BasePageModel):
_log.debug("start generating ...")
# Call model to generate:
tokens:list[VlmPredictionToken] = []
tokens: list[VlmPredictionToken] = []
output = ""
for token in self.stream_generate(
@ -105,14 +105,25 @@ class HuggingFaceMlxModel(BasePageModel):
max_tokens=4096,
verbose=False,
):
if len(token.logprobs.shape)==1:
tokens.append(VlmPredictionToken(text=token.text,
token=token.token,
logprob=token.logprobs[token.token]))
elif len(token.logprobs.shape)==2 and token.logprobs.shape[0]==1:
tokens.append(VlmPredictionToken(text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token]))
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
output += token.text
if "</doctag>" in token.text:
@ -121,9 +132,13 @@ class HuggingFaceMlxModel(BasePageModel):
generation_time = time.time() - start_time
page_tags = output
_log.debug(f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens)/generation_time} tokens/sec).")
page.predictions.vlm_response = VlmPrediction(text=page_tags,
generation_time=generation_time,
generated_tokens=tokens)
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,
generated_tokens=tokens,
)
yield page

View File

@ -43,9 +43,11 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
self.device = decide_device(accelerator_options.device)
if self.device=="mlx":
_log.warning(f"Mapping mlx to cpu for AutoModelForCausalLM")
self.device = cpu
if self.device == "mlx":
_log.warning(
"Mapping mlx to cpu for AutoModelForCausalLM, use MLX framework!"
)
self.device = "cpu"
self.use_cache = vlm_options.use_kv_cache
self.max_new_tokens = vlm_options.max_new_tokens
@ -53,7 +55,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
_log.debug(f"Available device for VLM: {self.device}")
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
# PARAMETERS:
if artifacts_path is None:
artifacts_path = HuggingFaceVlmModel.download_models(
self.vlm_options.repo_id
@ -117,8 +118,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=self.vlm_options.scale) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
hi_res_image = page.get_image(scale=self.vlm_options.scale)
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
@ -157,14 +157,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=response)
yield page
@ -172,11 +164,12 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
def formulate_prompt(self) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
# prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")

View File

@ -38,10 +38,8 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
BitsAndBytesConfig,
)
device = decide_device(accelerator_options.device)
self.device = device
_log.debug(f"Available device for HuggingFace VLM: {device}")
self.device = decide_device(accelerator_options.device)
_log.debug(f"Available device for HuggingFace VLM: {self.device}")
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@ -54,7 +52,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
# self.param_question = vlm_options.prompt # "Perform Layout Analysis."
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit, # True,
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
@ -68,7 +66,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
if not self.param_quantized:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
device_map=self.device,
torch_dtype=torch.bfloat16,
_attn_implementation=(
"flash_attention_2"
@ -82,7 +80,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
else:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
device_map=self.device,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
_attn_implementation=(
@ -94,29 +92,6 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
# trust_remote_code=True,
) # .to(self.device)
"""
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
"""
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@ -128,8 +103,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
hi_res_image = page.get_image(scale=self.vlm_options.scale)
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
@ -141,22 +115,9 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.param_question},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
# Define prompt structure
prompt = self.formulate_prompt()
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
)
@ -180,14 +141,26 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page
def formulate_prompt(self) -> str:
"""Formulate a prompt for the VLM."""
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.vlm_options.prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
return prompt

View File

@ -39,10 +39,15 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
)
self.device = decide_device(accelerator_options.device)
self.device = "cpu" # FIXME
self.use_cache = True
self.max_new_tokens = 64 # FIXME
if self.device == "mlx":
_log.warning(
"Mapping mlx to cpu for LlavaForConditionalGeneration, use MLX framework!"
)
self.device = "cpu"
self.use_cache = vlm_options.use_kv_cache
self.max_new_tokens = vlm_options.max_new_tokens
_log.debug(f"Available device for VLM: {self.device}")
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@ -54,9 +59,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
model_path = artifacts_path
_log.debug(f"model: {model_path}")
self.processor = AutoProcessor.from_pretrained(
artifacts_path,
trust_remote_code=self.trust_remote_code,
@ -98,12 +100,11 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
images = [hi_res_image]
# Define prompt structure
# prompt = "<s>[INST]Describe the images.\n[IMG][/INST]"
prompt = self.formulate_prompt()
inputs = self.processor(
text=prompt, images=images, return_tensors="pt"
).to(self.device) # .to("cuda")
).to(self.device)
# Generate response
start_time = time.time()
@ -113,8 +114,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
use_cache=self.use_cache, # Enables KV caching which can improve performance
)
print(generate_ids)
num_tokens = len(generate_ids[0])
generation_time = time.time() - start_time
@ -124,9 +123,11 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
clean_up_tokenization_spaces=False,
)[0]
page.predictions.vlm_response = VlmPrediction(text=response,
generated_tokens=num_tokens,
generation_time=generation_time)
page.predictions.vlm_response = VlmPrediction(
text=response,
generated_tokens=num_tokens,
generation_time=generation_time,
)
yield page

View File

@ -13,11 +13,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
from docling.datamodel.pipeline_model_specializations import (
ApiVlmOptions,
HuggingFaceVlmOptions,
InferenceFramework,
ResponseFormat,
)
from docling.datamodel.pipeline_options import (
VlmPipelineOptions,
)
from docling.datamodel.settings import settings

View File

@ -2,10 +2,12 @@ import logging
from pathlib import Path
from typing import Optional
from docling.datamodel.pipeline_options import (
granite_picture_description,
from docling.datamodel.pipeline_model_specializations import (
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.pipeline_options import (
granite_picture_description,
smolvlm_picture_description,
)
from docling.datamodel.settings import settings

View File

@ -11,10 +11,15 @@ from docling.datamodel.pipeline_options import (
InferenceFramework,
ResponseFormat,
VlmPipelineOptions,
smoldocling_vlm_mlx_conversion_options,
smoldocling_vlm_conversion_options,
granite_vision_vlm_conversion_options,
granite_vision_vlm_mlx_conversion_options,
granite_vision_vlm_ollama_conversion_options,
phi_vlm_conversion_options,
pixtral_12b_vlm_conversion_options,
pixtral_12b_vlm_mlx_conversion_options,
qwen25_vl_3b_vlm_mlx_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
@ -28,6 +33,7 @@ sources = [
pipeline_options = VlmPipelineOptions()
# If force_backend_text = True, text from backend will be used instead of generated text
pipeline_options.force_backend_text = False
pipeline_options.generate_page_images = True
## On GPU systems, enable flash_attention_2 with CUDA:
# pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
@ -37,11 +43,13 @@ pipeline_options.force_backend_text = False
# pipeline_options.vlm_options = smoldocling_vlm_conversion_options
## Pick a VLM model. Fast Apple Silicon friendly implementation for SmolDocling-256M via MLX
pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
# pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
## Alternative VLM models:
# pipeline_options.vlm_options = granite_vision_vlm_conversion_options
pipeline_options.vlm_options = phi_vlm_conversion_options
"""
pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="mistralai/Pixtral-12B-Base-2409",
@ -105,7 +113,7 @@ converter = DocumentConverter(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
}
},
)
out_path = Path("scratch")
@ -121,38 +129,43 @@ for source in sources:
res = converter.convert(source)
print("")
#print(res.document.export_to_markdown())
# print(res.document.export_to_markdown())
for i,page in enumerate(res.pages):
model_id = pipeline_options.vlm_options.repo_id.replace("/", "_")
fname = f"{model_id}-{res.input.file.stem}"
for i, page in enumerate(res.pages):
print("")
print(f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format}:")
print(
f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format}:"
)
print(page.predictions.vlm_response.text)
print(f" ---------- ")
print(" ---------- ")
print("===== Final output of the converted document =======")
with (out_path / f"{res.input.file.stem}.json").open("w") as fp:
with (out_path / f"{fname}.json").open("w") as fp:
fp.write(json.dumps(res.document.export_to_dict()))
res.document.save_as_json(
out_path / f"{res.input.file.stem}.json",
out_path / f"{fname}.json",
image_mode=ImageRefMode.PLACEHOLDER,
)
print(f" => produced {out_path / res.input.file.stem}.json")
print(f" => produced {out_path / fname}.json")
res.document.save_as_markdown(
out_path / f"{res.input.file.stem}.md",
out_path / f"{fname}.md",
image_mode=ImageRefMode.PLACEHOLDER,
)
print(f" => produced {out_path / res.input.file.stem}.md")
print(f" => produced {out_path / fname}.md")
res.document.save_as_html(
out_path / f"{res.input.file.stem}.html",
out_path / f"{fname}.html",
image_mode=ImageRefMode.EMBEDDED,
labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE],
# split_page_view=True,
split_page_view=True,
)
print(f" => produced {out_path / res.input.file.stem}.html")
print(f" => produced {out_path / fname}.html")
pg_num = res.document.num_pages()
print("")
@ -161,4 +174,3 @@ for source in sources:
f"Total document prediction time: {inference_time:.2f} seconds, pages: {pg_num}"
)
print("====================================================")