fixed the MyPy errors
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
parent a3716b1961
commit 7c67d2b2fe
@@ -29,6 +29,13 @@ from docling.datamodel.base_models import (
    OutputFormat,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_model_specializations import (
    VlmModelType,
    granite_vision_vlm_conversion_options,
    granite_vision_vlm_ollama_conversion_options,
    smoldocling_vlm_conversion_options,
    smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
@@ -39,12 +46,7 @@ from docling.datamodel.pipeline_options import (
    PdfPipeline,
    PdfPipelineOptions,
    TableFormerMode,
    VlmModelType,
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
    granite_vision_vlm_ollama_conversion_options,
    smoldocling_vlm_conversion_options,
    smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
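Note: after this commit, the VLM presets and `VlmModelType` are imported from `docling.datamodel.pipeline_model_specializations` instead of `docling.datamodel.pipeline_options`. A minimal sketch of the updated import style (the usage around the imports is illustrative, not part of the commit):

```python
# Sketch: new import locations for the VLM presets after this refactor.
from docling.datamodel.pipeline_model_specializations import (
    VlmModelType,
    smoldocling_vlm_conversion_options,
)
from docling.datamodel.pipeline_options import VlmPipelineOptions

pipeline_options = VlmPipelineOptions()
# Pick one of the predefined VLM presets that moved in this commit.
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
```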
@@ -1,10 +1,6 @@
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from docling_core.types.io import (
    DocumentStream,
)

from docling_core.types.doc import (
    BoundingBox,
    DocItemLabel,
@@ -14,6 +10,9 @@ from docling_core.types.doc import (
    TableCell,
)
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
from docling_core.types.io import (
    DocumentStream,
)

# DO NOT REMOVE; explicitly exposed from this location
from PIL.Image import Image
@@ -148,11 +147,13 @@ class BasePageElement(BaseModel):
class LayoutPrediction(BaseModel):
    clusters: List[Cluster] = []


class VlmPredictionToken(BaseModel):
    text: str = ""
    token: int = -1
    logprob: float = -1


class VlmPrediction(BaseModel):
    text: str = ""
    generated_tokens: list[VlmPredictionToken] = []
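Note: a short sketch of how the token-level prediction models above can be populated (field names as in the hunk; the values are illustrative, and the import path assumes these classes live in `docling.datamodel.base_models`, the file this hunk belongs to):

```python
# Illustrative construction of the VLM prediction containers shown above.
from docling.datamodel.base_models import VlmPrediction, VlmPredictionToken

tokens = [
    VlmPredictionToken(text="<doctag>", token=128000, logprob=-0.12),
    VlmPredictionToken(text="</doctag>", token=128001, logprob=-0.05),
]
prediction = VlmPrediction(text="<doctag></doctag>", generated_tokens=tokens)
print(len(prediction.generated_tokens))  # 2
```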
@@ -16,6 +16,12 @@ from pydantic import (
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import deprecated

from docling.datamodel.pipeline_model_specializations import (
    ApiVlmOptions,
    HuggingFaceVlmOptions,
    smoldocling_vlm_conversion_options,
)

_log = logging.getLogger(__name__)


@@ -121,24 +127,22 @@ class RapidOcrOptions(OcrOptions):
    lang: List[str] = [
        "english",
        "chinese",
    ] # However, language as a parameter is not supported by rapidocr yet and hence changing this options doesn't affect anything.
    # For more details on supported languages by RapidOCR visit https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
    ]
    # However, language as a parameter is not supported by rapidocr yet
    # and hence changing this options doesn't affect anything.

    # For more details on supported languages by RapidOCR visit
    # https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/

    # For more details on the following options visit
    # https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/

    # For more details on the following options visit https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/
    text_score: float = 0.5 # same default as rapidocr

    use_det: Optional[bool] = None # same default as rapidocr
    use_cls: Optional[bool] = None # same default as rapidocr
    use_rec: Optional[bool] = None # same default as rapidocr

    # class Device(Enum):
    #     CPU = "CPU"
    #     CUDA = "CUDA"
    #     DIRECTML = "DIRECTML"
    #     AUTO = "AUTO"

    # device: Device = Device.AUTO  # Default value is AUTO

    print_verbose: bool = False # same default as rapidocr

    det_model_path: Optional[str] = None # same default as rapidocr
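Note: a hedged sketch of configuring the RapidOCR options above (field names come from the hunk; attaching them through `ocr_options` on `PdfPipelineOptions` follows docling's usual pattern and is assumed here):

```python
# Illustrative RapidOCR configuration; None keeps rapidocr's own defaults.
from docling.datamodel.pipeline_options import PdfPipelineOptions, RapidOcrOptions

ocr_options = RapidOcrOptions(
    text_score=0.6,        # minimum confidence for recognized text
    use_det=True,          # run the detection stage explicitly
    print_verbose=False,
    det_model_path=None,   # or a local path to a custom detection model
)
pipeline_options = PdfPipelineOptions(do_ocr=True, ocr_options=ocr_options)
```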
@@ -243,110 +247,18 @@ class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
        return self.repo_id.replace("/", "--")


# SmolVLM
smolvlm_picture_description = PictureDescriptionVlmOptions(
    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
)
# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")

# GraniteVision
granite_picture_description = PictureDescriptionVlmOptions(
    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
    prompt="What is shown in this image?",
)


class BaseVlmOptions(BaseModel):
    kind: str
    prompt: str


class ResponseFormat(str, Enum):
    DOCTAGS = "doctags"
    MARKDOWN = "markdown"
    HTML = "html"


class InferenceFramework(str, Enum):
    MLX = "mlx"
    TRANSFORMERS = "transformers"
    OPENAI = "openai"
    TRANSFORMERS_AutoModelForVision2Seq = "transformers-AutoModelForVision2Seq"
    TRANSFORMERS_AutoModelForCausalLM = "transformers-AutoModelForCausalLM"
    TRANSFORMERS_LlavaForConditionalGeneration = (
        "transformers-LlavaForConditionalGeneration"
    )


class HuggingFaceVlmOptions(BaseVlmOptions):
    kind: Literal["hf_model_options"] = "hf_model_options"

    repo_id: str
    load_in_8bit: bool = True
    llm_int8_threshold: float = 6.0
    quantized: bool = False

    inference_framework: InferenceFramework
    response_format: ResponseFormat

    scale: float = 2.0

    use_kv_cache: bool = True
    max_new_tokens: int = 4096

    @property
    def repo_cache_folder(self) -> str:
        return self.repo_id.replace("/", "--")


class ApiVlmOptions(BaseVlmOptions):
    kind: Literal["api_model_options"] = "api_model_options"

    url: AnyUrl = AnyUrl(
        "http://localhost:11434/v1/chat/completions"
    ) # Default to ollama
    headers: Dict[str, str] = {}
    params: Dict[str, Any] = {}
    scale: float = 2.0
    timeout: float = 60
    response_format: ResponseFormat


smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.MLX,
)


smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
)

granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
    prompt="OCR the full page to markdown.",
    response_format=ResponseFormat.MARKDOWN,
    inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
)

granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
    url=AnyUrl("http://localhost:11434/v1/chat/completions"),
    params={"model": "granite3.2-vision:2b"},
    prompt="OCR the full page to markdown.",
    scale=1.0,
    timeout=120,
    response_format=ResponseFormat.MARKDOWN,
)


class VlmModelType(str, Enum):
    SMOLDOCLING = "smoldocling"
    GRANITE_VISION = "granite_vision"
    GRANITE_VISION_OLLAMA = "granite_vision_ollama"


# Define an enum for the backend options
class PdfBackend(str, Enum):
    """Enum of valid PDF backends."""
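Note: the VLM option classes and presets shown above move out of `pipeline_options.py` into `pipeline_model_specializations.py` (see the imports earlier in this diff). A user-defined preset would follow the same shape as the built-ins; a sketch, where the repo id is a placeholder and not part of the commit:

```python
# Sketch of a custom preset mirroring the built-in ones above.
from docling.datamodel.pipeline_model_specializations import (
    HuggingFaceVlmOptions,
    InferenceFramework,
    ResponseFormat,
)

my_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="my-org/my-vlm-model",  # placeholder repo id
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
    scale=2.0,
    max_new_tokens=4096,
)
```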
@@ -29,7 +29,7 @@ class HuggingFaceMlxModel(BasePageModel):

        self.vlm_options = vlm_options
        self.max_tokens = vlm_options.max_new_tokens

        if self.enabled:
            try:
                from mlx_vlm import generate, load  # type: ignore
@@ -60,7 +60,7 @@ class HuggingFaceMlxModel(BasePageModel):
            self.param_question = vlm_options.prompt

            ## Load the model
            self.vlm_model, self.processor = load(artifacts_path)
            self.vlm_model, self.processor = load(artifacts_path)
            self.config = load_config(artifacts_path)

    def __call__(
@@ -94,8 +94,8 @@ class HuggingFaceMlxModel(BasePageModel):
                    _log.debug("start generating ...")

                    # Call model to generate:
                    tokens:list[VlmPredictionToken] = []

                    tokens: list[VlmPredictionToken] = []

                    output = ""
                    for token in self.stream_generate(
                        self.vlm_model,
@@ -105,25 +105,40 @@ class HuggingFaceMlxModel(BasePageModel):
                        max_tokens=4096,
                        verbose=False,
                    ):
                        if len(token.logprobs.shape)==1:
                            tokens.append(VlmPredictionToken(text=token.text,
                                                             token=token.token,
                                                             logprob=token.logprobs[token.token]))
                        elif len(token.logprobs.shape)==2 and token.logprobs.shape[0]==1:
                            tokens.append(VlmPredictionToken(text=token.text,
                                                             token=token.token,
                                                             logprob=token.logprobs[0, token.token]))

                        output += token.text
                        if len(token.logprobs.shape) == 1:
                            tokens.append(
                                VlmPredictionToken(
                                    text=token.text,
                                    token=token.token,
                                    logprob=token.logprobs[token.token],
                                )
                            )
                        elif (
                            len(token.logprobs.shape) == 2
                            and token.logprobs.shape[0] == 1
                        ):
                            tokens.append(
                                VlmPredictionToken(
                                    text=token.text,
                                    token=token.token,
                                    logprob=token.logprobs[0, token.token],
                                )
                            )

                        output += token.text
                        if "</doctag>" in token.text:
                            break

                    generation_time = time.time() - start_time
                    page_tags = output

                    _log.debug(f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens)/generation_time} tokens/sec).")
                    page.predictions.vlm_response = VlmPrediction(text=page_tags,
                                                                  generation_time=generation_time,
                                                                  generated_tokens=tokens)

                    _log.debug(
                        f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                    )
                    page.predictions.vlm_response = VlmPrediction(
                        text=page_tags,
                        generation_time=generation_time,
                        generated_tokens=tokens,
                    )

                    yield page
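Note: the reformatted loop above branches on the shape of the logprobs array coming out of mlx-vlm's `stream_generate`. A standalone sketch of that selection logic, assuming a numpy-like array (the helper name is illustrative, not part of the commit):

```python
# Sketch of the logprob lookup above, assuming a numpy-like `logprobs` array
# and an integer `token_id`; extract_logprob is an illustrative helper name.
from typing import Optional

import numpy as np


def extract_logprob(logprobs: np.ndarray, token_id: int) -> Optional[float]:
    if len(logprobs.shape) == 1:
        # Plain vocabulary-sized vector: index directly by token id.
        return float(logprobs[token_id])
    if len(logprobs.shape) == 2 and logprobs.shape[0] == 1:
        # Leading batch dimension of size one: drop it before indexing.
        return float(logprobs[0, token_id])
    return None  # unexpected shape; the model code above simply skips the token


print(extract_logprob(np.log(np.full(8, 0.125)), 3))  # -2.079...
```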
@@ -43,17 +43,18 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):

            self.device = decide_device(accelerator_options.device)

            if self.device=="mlx":
                _log.warning(f"Mapping mlx to cpu for AutoModelForCausalLM")
                self.device = cpu

            if self.device == "mlx":
                _log.warning(
                    "Mapping mlx to cpu for AutoModelForCausalLM, use MLX framework!"
                )
                self.device = "cpu"

            self.use_cache = vlm_options.use_kv_cache
            self.max_new_tokens = vlm_options.max_new_tokens

            _log.debug(f"Available device for VLM: {self.device}")
            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

            # PARAMETERS:
            if artifacts_path is None:
                artifacts_path = HuggingFaceVlmModel.download_models(
                    self.vlm_options.repo_id
@@ -117,8 +118,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                with TimeRecorder(conv_res, "vlm"):
                    assert page.size is not None

                    hi_res_image = page.get_image(scale=self.vlm_options.scale) # 144dpi
                    # hi_res_image = page.get_image(scale=1.0) # 72dpi
                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                    if hi_res_image is not None:
                        im_width, im_height = hi_res_image.size
@@ -157,14 +157,6 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                    _log.debug(
                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                    )

                    # inference_time = time.time() - start_time
                    # tokens_per_second = num_tokens / generation_time
                    # print("")
                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
                    # print(f"Total tokens on page: {num_tokens:.2f}")
                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
                    # print("")
                    page.predictions.vlm_response = VlmPrediction(text=response)

                    yield page
@@ -172,11 +164,12 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
    def formulate_prompt(self) -> str:
        """Formulate a prompt for the VLM."""
        if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally

            user_prompt = "<|user|>"
            assistant_prompt = "<|assistant|>"
            prompt_suffix = "<|end|>"

            # prompt = f"{user_prompt}<|image_1|>Convert this image into MarkDown and only return the bare MarkDown!{prompt_suffix}{assistant_prompt}"
            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
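Note: the Phi-4 branch of `formulate_prompt` above wraps the configured prompt in the chat markers documented for `microsoft/Phi-4-multimodal-instruct`. A small sketch of the resulting string (the question text is just an example taken from the presets in this diff):

```python
# Illustrative assembly of the Phi-4 prompt string built above.
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
question = "OCR the full page to markdown."  # example prompt from the presets

prompt = f"{user_prompt}<|image_1|>{question}{prompt_suffix}{assistant_prompt}"
print(prompt)
# <|user|><|image_1|>OCR the full page to markdown.<|end|><|assistant|>
```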
@@ -38,10 +38,8 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                BitsAndBytesConfig,
            )

            device = decide_device(accelerator_options.device)
            self.device = device

            _log.debug(f"Available device for HuggingFace VLM: {device}")
            self.device = decide_device(accelerator_options.device)
            _log.debug(f"Available device for HuggingFace VLM: {self.device}")

            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

@@ -54,7 +52,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            self.param_question = vlm_options.prompt # "Perform Layout Analysis."
            # self.param_question = vlm_options.prompt # "Perform Layout Analysis."
            self.param_quantization_config = BitsAndBytesConfig(
                load_in_8bit=vlm_options.load_in_8bit, # True,
                llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
@@ -68,7 +66,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
            if not self.param_quantized:
                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                    artifacts_path,
                    device_map=device,
                    device_map=self.device,
                    torch_dtype=torch.bfloat16,
                    _attn_implementation=(
                        "flash_attention_2"
@@ -82,7 +80,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
            else:
                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                    artifacts_path,
                    device_map=device,
                    device_map=self.device,
                    torch_dtype="auto",
                    quantization_config=self.param_quantization_config,
                    _attn_implementation=(
@@ -94,29 +92,6 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                    # trust_remote_code=True,
                ) # .to(self.device)

    """
    @staticmethod
    def download_models(
        repo_id: str,
        local_dir: Optional[Path] = None,
        force: bool = False,
        progress: bool = False,
    ) -> Path:
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id=repo_id,
            force_download=force,
            local_dir=local_dir,
            # revision="v0.0.1",
        )

        return Path(download_path)
    """

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
@@ -128,8 +103,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                with TimeRecorder(conv_res, "vlm"):
                    assert page.size is not None

                    hi_res_image = page.get_image(scale=2.0) # 144dpi
                    # hi_res_image = page.get_image(scale=1.0) # 72dpi
                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                    if hi_res_image is not None:
                        im_width, im_height = hi_res_image.size
@@ -141,22 +115,9 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                    if hi_res_image.mode != "RGB":
                        hi_res_image = hi_res_image.convert("RGB")

                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": "This is a page from a document.",
                                },
                                {"type": "image"},
                                {"type": "text", "text": self.param_question},
                            ],
                        }
                    ]
                    prompt = self.processor.apply_chat_template(
                        messages, add_generation_prompt=False
                    )
                    # Define prompt structure
                    prompt = self.formulate_prompt()

                    inputs = self.processor(
                        text=prompt, images=[hi_res_image], return_tensors="pt"
                    )
@@ -180,14 +141,26 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
                    _log.debug(
                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                    )

                    # inference_time = time.time() - start_time
                    # tokens_per_second = num_tokens / generation_time
                    # print("")
                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
                    # print(f"Total tokens on page: {num_tokens:.2f}")
                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
                    # print("")
                    page.predictions.vlm_response = VlmPrediction(text=page_tags)

                    yield page

    def formulate_prompt(self) -> str:
        """Formulate a prompt for the VLM."""
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "This is a page from a document.",
                    },
                    {"type": "image"},
                    {"type": "text", "text": self.vlm_options.prompt},
                ],
            }
        ]
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=False
        )
        return prompt
@@ -39,10 +39,15 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
            )

            self.device = decide_device(accelerator_options.device)
            self.device = "cpu" # FIXME

            self.use_cache = True
            self.max_new_tokens = 64 # FIXME
            if self.device == "mlx":
                _log.warning(
                    "Mapping mlx to cpu for LlavaForConditionalGeneration, use MLX framework!"
                )
                self.device = "cpu"

            self.use_cache = vlm_options.use_kv_cache
            self.max_new_tokens = vlm_options.max_new_tokens

            _log.debug(f"Available device for VLM: {self.device}")
            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
@@ -54,9 +59,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            model_path = artifacts_path
            _log.debug(f"model: {model_path}")

            self.processor = AutoProcessor.from_pretrained(
                artifacts_path,
                trust_remote_code=self.trust_remote_code,
@@ -98,12 +100,11 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                    images = [hi_res_image]

                    # Define prompt structure
                    # prompt = "<s>[INST]Describe the images.\n[IMG][/INST]"
                    prompt = self.formulate_prompt()

                    inputs = self.processor(
                        text=prompt, images=images, return_tensors="pt"
                    ).to(self.device) # .to("cuda")
                    ).to(self.device)

                    # Generate response
                    start_time = time.time()
@@ -113,8 +114,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                        use_cache=self.use_cache, # Enables KV caching which can improve performance
                    )

                    print(generate_ids)

                    num_tokens = len(generate_ids[0])
                    generation_time = time.time() - start_time

@@ -123,10 +122,12 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )[0]

                    page.predictions.vlm_response = VlmPrediction(text=response,
                                                                  generated_tokens=num_tokens,
                                                                  generation_time=generation_time)

                    page.predictions.vlm_response = VlmPrediction(
                        text=response,
                        generated_tokens=num_tokens,
                        generation_time=generation_time,
                    )

                    yield page
@@ -13,11 +13,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
from docling.datamodel.pipeline_model_specializations import (
    ApiVlmOptions,
    HuggingFaceVlmOptions,
    InferenceFramework,
    ResponseFormat,
)
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
)
from docling.datamodel.settings import settings
@@ -2,10 +2,12 @@ import logging
from pathlib import Path
from typing import Optional

from docling.datamodel.pipeline_options import (
    granite_picture_description,
from docling.datamodel.pipeline_model_specializations import (
    smoldocling_vlm_conversion_options,
    smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.pipeline_options import (
    granite_picture_description,
    smolvlm_picture_description,
)
from docling.datamodel.settings import settings
@@ -11,10 +11,15 @@ from docling.datamodel.pipeline_options import (
    InferenceFramework,
    ResponseFormat,
    VlmPipelineOptions,
    smoldocling_vlm_mlx_conversion_options,
    smoldocling_vlm_conversion_options,
    granite_vision_vlm_conversion_options,
    granite_vision_vlm_mlx_conversion_options,
    granite_vision_vlm_ollama_conversion_options,
    phi_vlm_conversion_options,
    pixtral_12b_vlm_conversion_options,
    pixtral_12b_vlm_mlx_conversion_options,
    qwen25_vl_3b_vlm_mlx_conversion_options,
    smoldocling_vlm_conversion_options,
    smoldocling_vlm_mlx_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
@@ -28,6 +33,7 @@ sources = [
pipeline_options = VlmPipelineOptions()
# If force_backend_text = True, text from backend will be used instead of generated text
pipeline_options.force_backend_text = False
pipeline_options.generate_page_images = True

## On GPU systems, enable flash_attention_2 with CUDA:
# pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
@@ -37,11 +43,13 @@ pipeline_options.force_backend_text = False
# pipeline_options.vlm_options = smoldocling_vlm_conversion_options

## Pick a VLM model. Fast Apple Silicon friendly implementation for SmolDocling-256M via MLX
pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
# pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options

## Alternative VLM models:
# pipeline_options.vlm_options = granite_vision_vlm_conversion_options

pipeline_options.vlm_options = phi_vlm_conversion_options

"""
pixtral_vlm_conversion_options = HuggingFaceVlmOptions(
    repo_id="mistralai/Pixtral-12B-Base-2409",
@@ -105,7 +113,7 @@ converter = DocumentConverter(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        ),
    }
    },
)

out_path = Path("scratch")
@@ -121,39 +129,44 @@ for source in sources:
    res = converter.convert(source)

    print("")
    #print(res.document.export_to_markdown())
    # print(res.document.export_to_markdown())

    for i,page in enumerate(res.pages):
    model_id = pipeline_options.vlm_options.repo_id.replace("/", "_")
    fname = f"{model_id}-{res.input.file.stem}"

    for i, page in enumerate(res.pages):
        print("")
        print(f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format}:")
        print(
            f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format}:"
        )
        print(page.predictions.vlm_response.text)
        print(f" ---------- ")
        print(" ---------- ")

    print("===== Final output of the converted document =======")

    with (out_path / f"{res.input.file.stem}.json").open("w") as fp:

    with (out_path / f"{fname}.json").open("w") as fp:
        fp.write(json.dumps(res.document.export_to_dict()))

    res.document.save_as_json(
        out_path / f"{res.input.file.stem}.json",
        out_path / f"{fname}.json",
        image_mode=ImageRefMode.PLACEHOLDER,
    )
    print(f" => produced {out_path / res.input.file.stem}.json")

    print(f" => produced {out_path / fname}.json")

    res.document.save_as_markdown(
        out_path / f"{res.input.file.stem}.md",
        out_path / f"{fname}.md",
        image_mode=ImageRefMode.PLACEHOLDER,
    )
    print(f" => produced {out_path / res.input.file.stem}.md")

    print(f" => produced {out_path / fname}.md")

    res.document.save_as_html(
        out_path / f"{res.input.file.stem}.html",
        out_path / f"{fname}.html",
        image_mode=ImageRefMode.EMBEDDED,
        labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE],
        # split_page_view=True,
        split_page_view=True,
    )
    print(f" => produced {out_path / res.input.file.stem}.html")

    print(f" => produced {out_path / fname}.html")

    pg_num = res.document.num_pages()
    print("")
    inference_time = time.time() - start_time
@@ -161,4 +174,3 @@ for source in sources:
        f"Total document prediction time: {inference_time:.2f} seconds, pages: {pg_num}"
    )
    print("====================================================")
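Note: stripped of the commented-out alternatives, the example script above reduces to roughly the following (a sketch; the preset choice, input path, and output name are illustrative):

```python
# Minimal sketch of the VLM pipeline usage demonstrated by the example script.
from pathlib import Path

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_model_specializations import (
    smoldocling_vlm_conversion_options,
)
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions()
pipeline_options.vlm_options = smoldocling_vlm_conversion_options

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        ),
    },
)

res = converter.convert("tests/data/pdf/2305.03393v1-pg9.pdf")  # illustrative input

out_path = Path("scratch")
out_path.mkdir(exist_ok=True)
res.document.save_as_markdown(out_path / "output.md")
print(res.document.export_to_markdown())
```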