mirror of https://github.com/DS4SD/docling.git (synced 2025-07-27 04:24:45 +00:00)

commit 7c67d2b2fe (parent a3716b1961)

fixed the MyPy

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
@@ -29,6 +29,13 @@ from docling.datamodel.base_models import (
     OutputFormat,
 )
 from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_model_specializations import (
+    VlmModelType,
+    granite_vision_vlm_conversion_options,
+    granite_vision_vlm_ollama_conversion_options,
+    smoldocling_vlm_conversion_options,
+    smoldocling_vlm_mlx_conversion_options,
+)
 from docling.datamodel.pipeline_options import (
     AcceleratorDevice,
     AcceleratorOptions,
@@ -39,12 +46,7 @@ from docling.datamodel.pipeline_options import (
     PdfPipeline,
     PdfPipelineOptions,
     TableFormerMode,
-    VlmModelType,
     VlmPipelineOptions,
-    granite_vision_vlm_conversion_options,
-    granite_vision_vlm_ollama_conversion_options,
-    smoldocling_vlm_conversion_options,
-    smoldocling_vlm_mlx_conversion_options,
 )
 from docling.datamodel.settings import settings
 from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption

@@ -1,10 +1,6 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Dict, List, Optional, Union

-from docling_core.types.io import (
-    DocumentStream,
-)
-
 from docling_core.types.doc import (
     BoundingBox,
     DocItemLabel,
@@ -14,6 +10,9 @@ from docling_core.types.doc import (
     TableCell,
 )
 from docling_core.types.doc.page import SegmentedPdfPage, TextCell
+from docling_core.types.io import (
+    DocumentStream,
+)

 # DO NOT REMOVE; explicitly exposed from this location
 from PIL.Image import Image
@@ -148,11 +147,13 @@ class BasePageElement(BaseModel):
 class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []


 class VlmPredictionToken(BaseModel):
     text: str = ""
     token: int = -1
     logprob: float = -1


 class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []

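For orientation, a minimal sketch (not part of the commit) of how the prediction models touched in the hunk above can be populated; the token values are made up purely to illustrate the generated_tokens field.

    from docling.datamodel.base_models import VlmPrediction, VlmPredictionToken

    # Hypothetical tokens, only to show the generated_tokens field in use.
    tokens = [
        VlmPredictionToken(text="<doctag>", token=101, logprob=-0.11),
        VlmPredictionToken(text="</doctag>", token=102, logprob=-0.05),
    ]
    prediction = VlmPrediction(text="<doctag></doctag>", generated_tokens=tokens)
    print(len(prediction.generated_tokens))  # -> 2
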
@@ -16,6 +16,12 @@ from pydantic import (
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from typing_extensions import deprecated

+from docling.datamodel.pipeline_model_specializations import (
+    ApiVlmOptions,
+    HuggingFaceVlmOptions,
+    smoldocling_vlm_conversion_options,
+)
+
 _log = logging.getLogger(__name__)


@@ -121,24 +127,22 @@ class RapidOcrOptions(OcrOptions):
     lang: List[str] = [
         "english",
         "chinese",
-    ]  # However, language as a parameter is not supported by rapidocr yet and hence changing this options doesn't affect anything.
-    # For more details on supported languages by RapidOCR visit https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
+    ]
+    # However, language as a parameter is not supported by rapidocr yet
+    # and hence changing this options doesn't affect anything.
+
+    # For more details on supported languages by RapidOCR visit
+    # https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
+
+    # For more details on the following options visit
+    # https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/

-    # For more details on the following options visit https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/
     text_score: float = 0.5  # same default as rapidocr

     use_det: Optional[bool] = None  # same default as rapidocr
     use_cls: Optional[bool] = None  # same default as rapidocr
     use_rec: Optional[bool] = None  # same default as rapidocr

-    # class Device(Enum):
-    #     CPU = "CPU"
-    #     CUDA = "CUDA"
-    #     DIRECTML = "DIRECTML"
-    #     AUTO = "AUTO"
-
-    # device: Device = Device.AUTO  # Default value is AUTO
-
     print_verbose: bool = False  # same default as rapidocr

     det_model_path: Optional[str] = None  # same default as rapidocr

@@ -243,110 +247,18 @@ class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
         return self.repo_id.replace("/", "--")


+# SmolVLM
 smolvlm_picture_description = PictureDescriptionVlmOptions(
     repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
 )
-# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
+
+# GraniteVision
 granite_picture_description = PictureDescriptionVlmOptions(
     repo_id="ibm-granite/granite-vision-3.1-2b-preview",
     prompt="What is shown in this image?",
 )


-class BaseVlmOptions(BaseModel):
-    kind: str
-    prompt: str
-
-
-class ResponseFormat(str, Enum):
-    DOCTAGS = "doctags"
-    MARKDOWN = "markdown"
-    HTML = "html"
-
-
-class InferenceFramework(str, Enum):
-    MLX = "mlx"
-    TRANSFORMERS = "transformers"
-    OPENAI = "openai"
-    TRANSFORMERS_AutoModelForVision2Seq = "transformers-AutoModelForVision2Seq"
-    TRANSFORMERS_AutoModelForCausalLM = "transformers-AutoModelForCausalLM"
-    TRANSFORMERS_LlavaForConditionalGeneration = (
-        "transformers-LlavaForConditionalGeneration"
-    )
-
-
-class HuggingFaceVlmOptions(BaseVlmOptions):
-    kind: Literal["hf_model_options"] = "hf_model_options"
-
-    repo_id: str
-    load_in_8bit: bool = True
-    llm_int8_threshold: float = 6.0
-    quantized: bool = False
-
-    inference_framework: InferenceFramework
-    response_format: ResponseFormat
-
-    scale: float = 2.0
-
-    use_kv_cache: bool = True
-    max_new_tokens: int = 4096
-
-    @property
-    def repo_cache_folder(self) -> str:
-        return self.repo_id.replace("/", "--")
-
-
-class ApiVlmOptions(BaseVlmOptions):
-    kind: Literal["api_model_options"] = "api_model_options"
-
-    url: AnyUrl = AnyUrl(
-        "http://localhost:11434/v1/chat/completions"
-    )  # Default to ollama
-    headers: Dict[str, str] = {}
-    params: Dict[str, Any] = {}
-    scale: float = 2.0
-    timeout: float = 60
-    response_format: ResponseFormat
-
-
-smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
-    repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
-    prompt="Convert this page to docling.",
-    response_format=ResponseFormat.DOCTAGS,
-    inference_framework=InferenceFramework.MLX,
-)
-
-
-smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
-    repo_id="ds4sd/SmolDocling-256M-preview",
-    prompt="Convert this page to docling.",
-    response_format=ResponseFormat.DOCTAGS,
-    inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
-)
-
-granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
-    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
-    prompt="OCR the full page to markdown.",
-    response_format=ResponseFormat.MARKDOWN,
-    inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
-)
-
-granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
-    url=AnyUrl("http://localhost:11434/v1/chat/completions"),
-    params={"model": "granite3.2-vision:2b"},
-    prompt="OCR the full page to markdown.",
-    scale=1.0,
-    timeout=120,
-    response_format=ResponseFormat.MARKDOWN,
-)
-
-
-class VlmModelType(str, Enum):
-    SMOLDOCLING = "smoldocling"
-    GRANITE_VISION = "granite_vision"
-    GRANITE_VISION_OLLAMA = "granite_vision_ollama"
-
-
 # Define an enum for the backend options
 class PdfBackend(str, Enum):
     """Enum of valid PDF backends."""

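The classes and presets removed here are re-imported elsewhere in this commit from docling.datamodel.pipeline_model_specializations. A minimal sketch (assuming that module exposes the names shown in the import hunks) of how a caller would now select one of the relocated presets:

    from docling.datamodel.pipeline_model_specializations import (
        smoldocling_vlm_conversion_options,
    )
    from docling.datamodel.pipeline_options import VlmPipelineOptions

    # Same pattern as the example script touched later in this commit.
    pipeline_options = VlmPipelineOptions()
    pipeline_options.vlm_options = smoldocling_vlm_conversion_options
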
@@ -94,7 +94,7 @@ class HuggingFaceMlxModel(BasePageModel):
                 _log.debug("start generating ...")

                 # Call model to generate:
-                tokens:list[VlmPredictionToken] = []
+                tokens: list[VlmPredictionToken] = []

                 output = ""
                 for token in self.stream_generate(
@@ -105,14 +105,25 @@ class HuggingFaceMlxModel(BasePageModel):
                     max_tokens=4096,
                     verbose=False,
                 ):
-                    if len(token.logprobs.shape)==1:
-                        tokens.append(VlmPredictionToken(text=token.text,
-                                      token=token.token,
-                                      logprob=token.logprobs[token.token]))
-                    elif len(token.logprobs.shape)==2 and token.logprobs.shape[0]==1:
-                        tokens.append(VlmPredictionToken(text=token.text,
-                                      token=token.token,
-                                      logprob=token.logprobs[0, token.token]))
+                    if len(token.logprobs.shape) == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[token.token],
+                            )
+                        )
+                    elif (
+                        len(token.logprobs.shape) == 2
+                        and token.logprobs.shape[0] == 1
+                    ):
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[0, token.token],
+                            )
+                        )

                     output += token.text
                     if "</doctag>" in token.text:
@@ -121,9 +132,13 @@ class HuggingFaceMlxModel(BasePageModel):
                 generation_time = time.time() - start_time
                 page_tags = output

-                _log.debug(f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens)/generation_time} tokens/sec).")
-                page.predictions.vlm_response = VlmPrediction(text=page_tags,
-                                                               generation_time=generation_time,
-                                                               generated_tokens=tokens)
+                _log.debug(
+                    f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
+                )
+                page.predictions.vlm_response = VlmPrediction(
+                    text=page_tags,
+                    generation_time=generation_time,
+                    generated_tokens=tokens,
+                )

                 yield page

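As a side note, a small standalone sketch (not from the commit) of the two logprobs layouts that the reworked MLX branch above distinguishes; the shapes and values are assumptions made only to illustrate the indexing.

    import numpy as np

    token_id = 42

    # Shape (vocab,): index directly by token id, as in the first branch above.
    logprobs_1d = np.full((32000,), -9.0)
    print(logprobs_1d[token_id])

    # Shape (1, vocab): a leading batch axis of one, indexed as [0, token_id].
    logprobs_2d = logprobs_1d[np.newaxis, :]
    print(logprobs_2d[0, token_id])
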
@@ -43,9 +43,11 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):

         self.device = decide_device(accelerator_options.device)

-        if self.device=="mlx":
-            _log.warning(f"Mapping mlx to cpu for AutoModelForCausalLM")
-            self.device = cpu
+        if self.device == "mlx":
+            _log.warning(
+                "Mapping mlx to cpu for AutoModelForCausalLM, use MLX framework!"
+            )
+            self.device = "cpu"

         self.use_cache = vlm_options.use_kv_cache
         self.max_new_tokens = vlm_options.max_new_tokens
@@ -53,7 +55,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         _log.debug(f"Available device for VLM: {self.device}")
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")

-        # PARAMETERS:
         if artifacts_path is None:
             artifacts_path = HuggingFaceVlmModel.download_models(
                 self.vlm_options.repo_id
@@ -117,8 +118,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None

-                    hi_res_image = page.get_image(scale=self.vlm_options.scale)  # 144dpi
-                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                     if hi_res_image is not None:
                         im_width, im_height = hi_res_image.size
@@ -157,14 +157,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
-
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
                     page.predictions.vlm_response = VlmPrediction(text=response)

                     yield page
@@ -172,11 +164,12 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""
         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
+
             user_prompt = "<|user|>"
             assistant_prompt = "<|assistant|>"
             prompt_suffix = "<|end|>"

-            # prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
             prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
             _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")

@@ -38,10 +38,8 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                 BitsAndBytesConfig,
             )

-            device = decide_device(accelerator_options.device)
-            self.device = device
-
-            _log.debug(f"Available device for HuggingFace VLM: {device}")
+            self.device = decide_device(accelerator_options.device)
+            _log.debug(f"Available device for HuggingFace VLM: {self.device}")

             repo_cache_folder = vlm_options.repo_id.replace("/", "--")

@@ -54,7 +52,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder

-            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
+            # self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
             self.param_quantization_config = BitsAndBytesConfig(
                 load_in_8bit=vlm_options.load_in_8bit,  # True,
                 llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
@@ -68,7 +66,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
             if not self.param_quantized:
                 self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                     artifacts_path,
-                    device_map=device,
+                    device_map=self.device,
                     torch_dtype=torch.bfloat16,
                     _attn_implementation=(
                         "flash_attention_2"
@@ -82,7 +80,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
             else:
                 self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                     artifacts_path,
-                    device_map=device,
+                    device_map=self.device,
                     torch_dtype="auto",
                     quantization_config=self.param_quantization_config,
                     _attn_implementation=(
@@ -94,29 +92,6 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                     # trust_remote_code=True,
                 )  # .to(self.device)

-            """
-            @staticmethod
-            def download_models(
-                repo_id: str,
-                local_dir: Optional[Path] = None,
-                force: bool = False,
-                progress: bool = False,
-            ) -> Path:
-                from huggingface_hub import snapshot_download
-                from huggingface_hub.utils import disable_progress_bars
-
-                if not progress:
-                    disable_progress_bars()
-                download_path = snapshot_download(
-                    repo_id=repo_id,
-                    force_download=force,
-                    local_dir=local_dir,
-                    # revision="v0.0.1",
-                )
-
-                return Path(download_path)
-            """
-
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -128,8 +103,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None

-                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
-                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                     if hi_res_image is not None:
                         im_width, im_height = hi_res_image.size
@@ -141,22 +115,9 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                         if hi_res_image.mode != "RGB":
                             hi_res_image = hi_res_image.convert("RGB")

-                    messages = [
-                        {
-                            "role": "user",
-                            "content": [
-                                {
-                                    "type": "text",
-                                    "text": "This is a page from a document.",
-                                },
-                                {"type": "image"},
-                                {"type": "text", "text": self.param_question},
-                            ],
-                        }
-                    ]
-                    prompt = self.processor.apply_chat_template(
-                        messages, add_generation_prompt=False
-                    )
+                    # Define prompt structure
+                    prompt = self.formulate_prompt()
+
                     inputs = self.processor(
                         text=prompt, images=[hi_res_image], return_tensors="pt"
                     )
@@ -180,14 +141,26 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
-
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)

                 yield page
+
+    def formulate_prompt(self) -> str:
+        """Formulate a prompt for the VLM."""
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "This is a page from a document.",
+                    },
+                    {"type": "image"},
+                    {"type": "text", "text": self.vlm_options.prompt},
+                ],
+            }
+        ]
+        prompt = self.processor.apply_chat_template(
+            messages, add_generation_prompt=False
+        )
+        return prompt

@@ -39,10 +39,15 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
             )

             self.device = decide_device(accelerator_options.device)
-            self.device = "cpu"  # FIXME

-            self.use_cache = True
-            self.max_new_tokens = 64  # FIXME
+            if self.device == "mlx":
+                _log.warning(
+                    "Mapping mlx to cpu for LlavaForConditionalGeneration, use MLX framework!"
+                )
+                self.device = "cpu"
+
+            self.use_cache = vlm_options.use_kv_cache
+            self.max_new_tokens = vlm_options.max_new_tokens

             _log.debug(f"Available device for VLM: {self.device}")
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@@ -54,9 +59,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder

-            model_path = artifacts_path
-            _log.debug(f"model: {model_path}")
-
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=self.trust_remote_code,
@@ -98,12 +100,11 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                     images = [hi_res_image]

                     # Define prompt structure
-                    # prompt = "<s>[INST]Describe the images.\n[IMG][/INST]"
                     prompt = self.formulate_prompt()

                     inputs = self.processor(
                         text=prompt, images=images, return_tensors="pt"
-                    ).to(self.device)  # .to("cuda")
+                    ).to(self.device)

                     # Generate response
                     start_time = time.time()
@@ -113,8 +114,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                         use_cache=self.use_cache,  # Enables KV caching which can improve performance
                     )

-                    print(generate_ids)
-
                     num_tokens = len(generate_ids[0])
                     generation_time = time.time() - start_time

@@ -124,9 +123,11 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                         clean_up_tokenization_spaces=False,
                     )[0]

-                    page.predictions.vlm_response = VlmPrediction(text=response,
-                                                                  generated_tokens=num_tokens,
-                                                                  generation_time=generation_time)
+                    page.predictions.vlm_response = VlmPrediction(
+                        text=response,
+                        generated_tokens=num_tokens,
+                        generation_time=generation_time,
+                    )

                 yield page

@@ -13,11 +13,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import ConversionResult, InputDocument
-from docling.datamodel.pipeline_options import (
+from docling.datamodel.pipeline_model_specializations import (
     ApiVlmOptions,
     HuggingFaceVlmOptions,
     InferenceFramework,
     ResponseFormat,
+)
+from docling.datamodel.pipeline_options import (
     VlmPipelineOptions,
 )
 from docling.datamodel.settings import settings

@@ -2,10 +2,12 @@ import logging
 from pathlib import Path
 from typing import Optional

-from docling.datamodel.pipeline_options import (
-    granite_picture_description,
+from docling.datamodel.pipeline_model_specializations import (
     smoldocling_vlm_conversion_options,
     smoldocling_vlm_mlx_conversion_options,
+)
+from docling.datamodel.pipeline_options import (
+    granite_picture_description,
     smolvlm_picture_description,
 )
 from docling.datamodel.settings import settings

@@ -11,10 +11,15 @@ from docling.datamodel.pipeline_options import (
     InferenceFramework,
     ResponseFormat,
     VlmPipelineOptions,
-    smoldocling_vlm_mlx_conversion_options,
-    smoldocling_vlm_conversion_options,
     granite_vision_vlm_conversion_options,
+    granite_vision_vlm_mlx_conversion_options,
     granite_vision_vlm_ollama_conversion_options,
+    phi_vlm_conversion_options,
+    pixtral_12b_vlm_conversion_options,
+    pixtral_12b_vlm_mlx_conversion_options,
+    qwen25_vl_3b_vlm_mlx_conversion_options,
+    smoldocling_vlm_conversion_options,
+    smoldocling_vlm_mlx_conversion_options,
 )
 from docling.document_converter import DocumentConverter, PdfFormatOption
 from docling.pipeline.vlm_pipeline import VlmPipeline
@@ -28,6 +33,7 @@ sources = [
 pipeline_options = VlmPipelineOptions()
 # If force_backend_text = True, text from backend will be used instead of generated text
 pipeline_options.force_backend_text = False
+pipeline_options.generate_page_images = True

 ## On GPU systems, enable flash_attention_2 with CUDA:
 # pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
@@ -37,11 +43,13 @@ pipeline_options.force_backend_text = False
 # pipeline_options.vlm_options = smoldocling_vlm_conversion_options

 ## Pick a VLM model. Fast Apple Silicon friendly implementation for SmolDocling-256M via MLX
-pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
+# pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options

 ## Alternative VLM models:
 # pipeline_options.vlm_options = granite_vision_vlm_conversion_options

+pipeline_options.vlm_options = phi_vlm_conversion_options
+
 """
 pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
     repo_id="mistralai/Pixtral-12B-Base-2409",
@@ -105,7 +113,7 @@ converter = DocumentConverter(
             pipeline_cls=VlmPipeline,
             pipeline_options=pipeline_options,
         ),
-    }
+    },
 )

 out_path = Path("scratch")
@@ -121,38 +129,43 @@ for source in sources:
     res = converter.convert(source)

     print("")
-    #print(res.document.export_to_markdown())
+    # print(res.document.export_to_markdown())

-    for i,page in enumerate(res.pages):
+    model_id = pipeline_options.vlm_options.repo_id.replace("/", "_")
+    fname = f"{model_id}-{res.input.file.stem}"
+
+    for i, page in enumerate(res.pages):
         print("")
-        print(f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format}:")
+        print(
+            f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format}:"
+        )
         print(page.predictions.vlm_response.text)
-        print(f" ---------- ")
+        print(" ---------- ")

     print("===== Final output of the converted document =======")

-    with (out_path / f"{res.input.file.stem}.json").open("w") as fp:
+    with (out_path / f"{fname}.json").open("w") as fp:
         fp.write(json.dumps(res.document.export_to_dict()))

     res.document.save_as_json(
-        out_path / f"{res.input.file.stem}.json",
+        out_path / f"{fname}.json",
         image_mode=ImageRefMode.PLACEHOLDER,
     )
-    print(f" => produced {out_path / res.input.file.stem}.json")
+    print(f" => produced {out_path / fname}.json")

     res.document.save_as_markdown(
-        out_path / f"{res.input.file.stem}.md",
+        out_path / f"{fname}.md",
         image_mode=ImageRefMode.PLACEHOLDER,
     )
-    print(f" => produced {out_path / res.input.file.stem}.md")
+    print(f" => produced {out_path / fname}.md")

     res.document.save_as_html(
-        out_path / f"{res.input.file.stem}.html",
+        out_path / f"{fname}.html",
         image_mode=ImageRefMode.EMBEDDED,
         labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE],
-        # split_page_view=True,
+        split_page_view=True,
     )
-    print(f" => produced {out_path / res.input.file.stem}.html")
+    print(f" => produced {out_path / fname}.html")

     pg_num = res.document.num_pages()
     print("")
@@ -161,4 +174,3 @@ for source in sources:
         f"Total document prediction time: {inference_time:.2f} seconds, pages: {pg_num}"
     )
     print("====================================================")
-

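Read end to end, the example changes amount to the flow below; this is a condensed sketch stitched together from the hunks above, not the full script, and the input path and the chosen preset are illustrative assumptions.

    from pathlib import Path

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_model_specializations import (
        smoldocling_vlm_conversion_options,
    )
    from docling.datamodel.pipeline_options import VlmPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    pipeline_options = VlmPipelineOptions()
    pipeline_options.generate_page_images = True
    pipeline_options.vlm_options = smoldocling_vlm_conversion_options

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline,
                pipeline_options=pipeline_options,
            ),
        },
    )

    res = converter.convert("document.pdf")  # hypothetical input path

    # Output names are now prefixed with the model id, as in the hunk above.
    model_id = pipeline_options.vlm_options.repo_id.replace("/", "_")
    fname = f"{model_id}-{res.input.file.stem}"
    res.document.save_as_markdown(Path("scratch") / f"{fname}.md")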