Mirror of https://github.com/DS4SD/docling.git, synced 2025-07-27 04:24:45 +00:00
Addressing PR comments: added an enabled property to SmolDocling, a related VLM pipeline option, and a few other minor things
Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
commit 853544ba11
parent b0935daec4
@@ -255,9 +255,7 @@ granite_picture_description = PictureDescriptionVlmOptions(


 class SmolDoclingOptions(BaseModel):
-    artifacts_path: str = ""
-    question: str = "Convert this page to docling."  # "Perform Layout Analysis."
-
+    question: str = "Convert this page to docling."
     load_in_8bit: bool = True
     llm_int8_threshold: float = 6.0
     quantized: bool = False
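Review note: artifacts_path is not dropped here, it moves from the per-model SmolDoclingOptions onto the pipeline options introduced in the next hunk. A before/after sketch of pointing at a custom weights folder (the path is a placeholder):

    # before this commit:
    vlm_options = SmolDoclingOptions(artifacts_path="/models/smoldocling")
    # after this commit:
    pipeline_options = VlmPipelineOptions(artifacts_path="/models/smoldocling")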
@@ -294,7 +292,24 @@ class PipelineOptions(BaseModel):
     enable_remote_services: bool = False


-class PdfPipelineOptions(PipelineOptions):
+class PaginatedPipelineOptions(PipelineOptions):
+    images_scale: float = 1.0
+    generate_page_images: bool = False
+    generate_picture_images: bool = False
+
+
+class VlmPipelineOptions(PaginatedPipelineOptions):
+    artifacts_path: Optional[Union[Path, str]] = None
+    do_vlm: bool = True  # True: perform inference of Visual Language Model
+
+    force_backend_text: bool = (
+        False  # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text
+    vlm_options: Union[SmolDoclingOptions,] = Field(SmolDoclingOptions())
+
+
+class PdfPipelineOptions(PaginatedPipelineOptions):
     """Options for the PDF pipeline."""

     artifacts_path: Optional[Union[Path, str]] = None
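Review note: a minimal sketch of how the new options hierarchy is meant to be used (assumes the docling build from this branch; do_vlm and force_backend_text are the fields added above):

    from docling.datamodel.pipeline_options import SmolDoclingOptions, VlmPipelineOptions

    pipeline_options = VlmPipelineOptions()
    pipeline_options.generate_page_images = True   # inherited from PaginatedPipelineOptions
    pipeline_options.do_vlm = True                 # False skips loading the VLM entirely
    pipeline_options.force_backend_text = False    # True reuses backend text over generated text
    pipeline_options.vlm_options = SmolDoclingOptions(
        question="Convert this page to docling.",
    )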
@@ -3,14 +3,6 @@ import time
 from pathlib import Path
 from typing import Iterable, List, Optional

-import torch
-from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
-from transformers import (  # type: ignore
-    AutoProcessor,
-    BitsAndBytesConfig,
-    Idefics3ForConditionalGeneration,
-)
-
 from docling.datamodel.base_models import DocTagsPrediction, Page
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
@@ -32,38 +24,76 @@ class SmolDoclingModel(BasePageModel):

     def __init__(
         self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
         accelerator_options: AcceleratorOptions,
         vlm_options: SmolDoclingOptions,
     ):
-        device = decide_device(accelerator_options.device)
-        self.device = device
-        _log.info("Available device for SmolDocling: {}".format(device))
-
-        # PARAMETERS:
-        artifacts_path = Path(vlm_options.artifacts_path)
-        self.param_question = vlm_options.question  # "Perform Layout Analysis."
-        self.param_quantization_config = BitsAndBytesConfig(
-            load_in_8bit=vlm_options.load_in_8bit,  # True,
-            llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
-        )
-        self.param_quantized = vlm_options.quantized  # False
-
-        self.processor = AutoProcessor.from_pretrained(artifacts_path)
-        if not self.param_quantized:
-            self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
-                artifacts_path,
-                device_map=device,
-                torch_dtype=torch.bfloat16,
-                # _attn_implementation="flash_attention_2",
-            )
-            self.vlm_model = self.vlm_model.to(device)
-        else:
-            self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
-                artifacts_path,
-                device_map=device,
-                torch_dtype="auto",
-                quantization_config=self.param_quantization_config,
-            )
+        self.enabled = enabled
+
+        if self.enabled:
+            import torch
+            from transformers import (  # type: ignore
+                AutoProcessor,
+                BitsAndBytesConfig,
+                Idefics3ForConditionalGeneration,
+            )
+
+            device = decide_device(accelerator_options.device)
+            self.device = device
+
+            _log.debug("Available device for SmolDocling: {}".format(device))
+
+            repo_cache_folder = self._repo_id.replace("/", "--")
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models()
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.question  # "Perform Layout Analysis."
+            self.param_quantization_config = BitsAndBytesConfig(
+                load_in_8bit=vlm_options.load_in_8bit,  # True,
+                llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
+            )
+            self.param_quantized = vlm_options.quantized  # False
+
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
+            if not self.param_quantized:
+                self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
+                    artifacts_path,
+                    # device_map=device,
+                    torch_dtype=torch.bfloat16,
+                )
+                self.vlm_model = self.vlm_model.to(device)
+            else:
+                self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
+                    artifacts_path,
+                    # device_map=device,
+                    torch_dtype="auto",
+                    quantization_config=self.param_quantization_config,
+                ).to(device)
+
+    @staticmethod
+    def download_models(
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=SmolDoclingModel._repo_id,
+            force_download=force,
+            local_dir=local_dir,
+            # revision="v0.0.1",
+        )
+
+        return Path(download_path)

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
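Review note: with download_models now a @staticmethod, weights can be prefetched outside the pipeline, e.g. during an image build. A sketch using only the signature added above (the module path docling.models.smol_docling_model is assumed, not confirmed by this diff):

    from pathlib import Path

    from docling.models.smol_docling_model import SmolDoclingModel  # module path assumed

    # Prefetch SmolDocling weights so the pipeline can later be constructed
    # with artifacts_path pointing at this folder, without network access.
    weights_dir = SmolDoclingModel.download_models(
        local_dir=Path("./model_artifacts"),  # placeholder location
        force=False,
        progress=True,
    )
    print(f"SmolDocling weights cached at {weights_dir}")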
@@ -116,7 +116,10 @@ class PaginatedPipeline(BasePipeline):  # TODO this is a bad name.

     def __init__(self, pipeline_options: PipelineOptions):
         super().__init__(pipeline_options)
-        self.keep_backend = True
+        self.keep_backend = (
+            True  # For now, need to be able to query for page size post prediction
+        )
+        # self.keep_backend = False

     def _apply_on_pages(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -56,7 +56,6 @@ class StandardPdfPipeline(PaginatedPipeline):

     def __init__(self, pipeline_options: PdfPipelineOptions):
         super().__init__(pipeline_options)
-        print("------> Init Standard PDF Pipeline!")
         self.pipeline_options: PdfPipelineOptions

         artifacts_path: Optional[Path] = None
@@ -97,6 +97,8 @@ class VlmPipeline(PaginatedPipeline):

         self.build_pipe = [
             SmolDoclingModel(
+                enabled=pipeline_options.do_vlm,
+                artifacts_path=artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
                 vlm_options=self.pipeline_options.vlm_options,
             ),
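Review note: the new keyword arguments also make the model cheap to construct when disabled, which is easy to check stand-alone (import paths for SmolDoclingModel and AcceleratorOptions are assumed here):

    from docling.datamodel.pipeline_options import (  # paths assumed
        AcceleratorOptions,
        SmolDoclingOptions,
    )
    from docling.models.smol_docling_model import SmolDoclingModel  # path assumed

    # With enabled=False the __init__ above returns before the lazy
    # torch/transformers imports, so no heavy dependencies are loaded
    # and no weights are fetched.
    model = SmolDoclingModel(
        enabled=False,
        artifacts_path=None,
        accelerator_options=AcceleratorOptions(),
        vlm_options=SmolDoclingOptions(),
    )
    assert model.enabled is False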
@@ -297,6 +299,7 @@ class VlmPipeline(PaginatedPipeline):
                 token
                 for token in tokens
                 if not (token.startswith("<loc_") or token in ["<otsl>", "</otsl>"])
+                # if not (token.startswith(DocumentToken.BEG_LOC) or token in [DocumentToken.BEG_OTSL, DocumentToken.END_OTSL])
             ]
             # Split the string by those tokens to get the in-between text
             text_parts = re.split(pattern, s)
@@ -304,6 +307,7 @@ class VlmPipeline(PaginatedPipeline):
                 token
                 for token in text_parts
                 if not (token.startswith("<loc_") or token in ["<otsl>", "</otsl>"])
+                # if not (token.startswith(DocumentToken.BEG_LOC) or token in [DocumentToken.BEG_OTSL, DocumentToken.END_OTSL])
             ]
             # Remove any empty or purely whitespace strings from text_parts
             text_parts = [part for part in text_parts if part.strip()]
@@ -347,10 +351,15 @@ class VlmPipeline(PaginatedPipeline):

         # Regex for all recognized tags
         tag_pattern = (
-            r"<(?P<tag>title|document_index|otsl|section_header_level_1|checkbox_selected|"
-            r"checkbox_unselected|text|page_header|page_footer|formula|caption|picture|"
-            r"list_item|footnote|code)>.*?</(?P=tag)>"
+            rf"<(?P<tag>{DocItemLabel.TITLE}|{DocItemLabel.DOCUMENT_INDEX}|"
+            rf"{DocItemLabel.CHECKBOX_UNSELECTED}|{DocItemLabel.CHECKBOX_SELECTED}|"
+            rf"{DocItemLabel.TEXT}|{DocItemLabel.PAGE_HEADER}|"
+            rf"{DocItemLabel.PAGE_FOOTER}|{DocItemLabel.FORMULA}|"
+            rf"{DocItemLabel.CAPTION}|{DocItemLabel.PICTURE}|"
+            rf"{DocItemLabel.LIST_ITEM}|{DocItemLabel.FOOTNOTE}|{DocItemLabel.CODE}|"
+            rf"{DocItemLabel.SECTION_HEADER}_level_1|otsl)>.*?</(?P=tag)>"
         )

         pattern = re.compile(tag_pattern, re.DOTALL)

         # Go through each match in order
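Review note: the rewritten pattern still relies on a named group plus a backreference; the mechanics can be checked in isolation with literal tag names (this sketch uses plain strings where the real code interpolates DocItemLabel values):

    import re

    # Two literal tags standing in for the DocItemLabel interpolations above.
    tag_pattern = r"<(?P<tag>title|text|otsl)>.*?</(?P=tag)>"
    pattern = re.compile(tag_pattern, re.DOTALL)

    # re.DOTALL lets .*? span newlines inside a tag body, and (?P=tag)
    # forces the closing tag to match the tag that opened it.
    doc_tags = "<title>Report</title>\n<text>Hello\nworld</text>"
    assert [m.group("tag") for m in pattern.finditer(doc_tags)] == ["title", "text"]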
@@ -438,8 +447,8 @@ class VlmPipeline(PaginatedPipeline):
         return doc

     @classmethod
-    def get_default_options(cls) -> PdfPipelineOptions:
-        return PdfPipelineOptions()
+    def get_default_options(cls) -> VlmPipelineOptions:
+        return VlmPipelineOptions()

     @classmethod
     def is_backend_supported(cls, backend: AbstractDocumentBackend):
@@ -19,7 +19,7 @@ pipeline_options = VlmPipelineOptions() # artifacts_path="~/local_model_artifac
 pipeline_options.generate_page_images = True
 # If force_backend_text = True, text from backend will be used instead of generated text
 pipeline_options.force_backend_text = False
-
+# pipeline_options.do_vlm = True - use False to disable VLM model (i.e. SmolDocling), extra python imports will not be performed

 vlm_options = SmolDoclingOptions(
     # question="Convert this page to docling.",
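Review note: put together, the example file touched above boils down to this end-to-end sketch (assumes this branch; the VlmPipeline import path and the input document are placeholders):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import SmolDoclingOptions, VlmPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline  # module path assumed

    pipeline_options = VlmPipelineOptions()
    pipeline_options.generate_page_images = True
    pipeline_options.force_backend_text = False
    # pipeline_options.do_vlm = False  # would skip loading SmolDocling entirely
    pipeline_options.vlm_options = SmolDoclingOptions()

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )
    result = converter.convert("example.pdf")  # placeholder input
    print(result.document.export_to_markdown())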