Fixed the MyPy complaints

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-07-09 06:48:03 +02:00
parent c10e2920a4
commit dcf6fd6a41
5 changed files with 91 additions and 28 deletions

@@ -11,12 +11,13 @@ from docling.datamodel.pipeline_options_asr_model import (
     # ApiAsrOptions,
     InferenceAsrFramework,
     InlineAsrNativeWhisperOptions,
+    InlineAsrOptions,
     TransformersModelType,
 )

 _log = logging.getLogger(__name__)

-WHISPER_TINY = InlineAsrNativeWhisperOptions(
+WHISPER_TINY: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="tiny",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -27,7 +28,7 @@ WHISPER_TINY = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )

-WHISPER_SMALL = InlineAsrNativeWhisperOptions(
+WHISPER_SMALL: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="small",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -38,7 +39,7 @@ WHISPER_SMALL = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )

-WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
+WHISPER_MEDIUM: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="medium",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -49,7 +50,7 @@ WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )

-WHISPER_BASE = InlineAsrNativeWhisperOptions(
+WHISPER_BASE: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="base",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -60,7 +61,7 @@ WHISPER_BASE = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )

-WHISPER_LARGE = InlineAsrNativeWhisperOptions(
+WHISPER_LARGE: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="large",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
@@ -71,7 +72,7 @@ WHISPER_LARGE = InlineAsrNativeWhisperOptions(
     max_time_chunk=30.0,
 )

-WHISPER_TURBO = InlineAsrNativeWhisperOptions(
+WHISPER_TURBO: InlineAsrOptions = InlineAsrNativeWhisperOptions(
     repo_id="turbo",
     inference_framework=InferenceAsrFramework.WHISPER,
     verbose=True,
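
The only change in this file is the explicit `InlineAsrOptions` annotation on each preset, so MyPy treats them as the general options type instead of inferring the concrete `InlineAsrNativeWhisperOptions`. A minimal sketch of why such widening helps, with hypothetical stand-in classes (not the docling ones):

    class AsrOptions: ...                  # stand-in for InlineAsrOptions
    class WhisperOptions(AsrOptions): ...  # stand-in for InlineAsrNativeWhisperOptions
    class ApiOptions(AsrOptions): ...      # hypothetical sibling subtype

    WHISPER_TINY: AsrOptions = WhisperOptions()

    # With the annotation, a variable holding one preset can later be
    # rebound to any other AsrOptions value without a MyPy error:
    current: AsrOptions = WHISPER_TINY
    current = ApiOptions()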

@@ -16,7 +16,15 @@ from docling.datamodel import asr_model_specs

 # Import the following for backwards compatibility
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
-from docling.datamodel.asr_model_specs import WHISPER_TINY as whisper_tiny
+from docling.datamodel.asr_model_specs import (
+    WHISPER_BASE,
+    WHISPER_LARGE,
+    WHISPER_MEDIUM,
+    WHISPER_SMALL,
+    WHISPER_TINY,
+    WHISPER_TINY as whisper_tiny,
+    WHISPER_TURBO,
+)
 from docling.datamodel.layout_model_specs import (
     LayoutModelConfig,
     docling_layout_egret_large,
@@ -279,13 +287,12 @@ class VlmPipelineOptions(PaginatedPipelineOptions):

 class LayoutOptions(BaseModel):
     """Options for layout processing."""

     repo_id: str = "ds4sd/docling-layout-heron"
-    create_orphan_clusters: bool = True  # Whether to create clusters for orphaned cells
     model_spec: LayoutModelConfig = docling_layout_v2


 class AsrPipelineOptions(PipelineOptions):
-    asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY
+    asr_options: Union[InlineAsrOptions] = WHISPER_TINY
     artifacts_path: Optional[Union[Path, str]] = None
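
Since all the Whisper presets are now imported at module level (and re-exported for backwards compatibility), the default can be referenced as `WHISPER_TINY` directly instead of through the `asr_model_specs` module. A usage sketch under the import paths shown in the hunk:

    from docling.datamodel.asr_model_specs import WHISPER_SMALL
    from docling.datamodel.pipeline_options import AsrPipelineOptions

    # Default is WHISPER_TINY; any of the re-exported presets can be swapped in.
    opts = AsrPipelineOptions(asr_options=WHISPER_SMALL)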

@@ -16,7 +16,7 @@ from docling.datamodel.document import ConversionResult
 from docling.datamodel.layout_model_specs import LayoutModelConfig, docling_layout_v2
 from docling.datamodel.pipeline_options import LayoutOptions
 from docling.datamodel.settings import settings
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseLayoutModel
 from docling.models.utils.hf_model_download import download_hf_model
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.layout_postprocessor import LayoutPostprocessor
@@ -26,7 +26,7 @@ from docling.utils.visualization import draw_clusters

 _log = logging.getLogger(__name__)

-class LayoutModel(BasePageModel):
+class LayoutModel(BaseLayoutModel):
     TEXT_ELEM_LABELS = [
         DocItemLabel.TEXT,
         DocItemLabel.FOOTNOTE,
@@ -179,7 +179,9 @@ class LayoutModel(BasePageModel):
                     )
                     clusters.append(cluster)
                 """
-                predicted_clusters = self.predict_on_page(page_image=page_image)

+                predicted_clusters = self.predict_on_page_image(
+                    page_image=page_image
+                )

                 if settings.debug.visualize_raw_layout:
                     self.draw_clusters_and_cells_side_by_side(
@@ -216,7 +218,9 @@ class LayoutModel(BasePageModel):
                 )
                 """
                 page, processed_clusters, processed_cells = (
-                    self.postprocess_on_page(page=page, clusters=predicted_clusters)
+                    self.postprocess_on_page_image(
+                        page=page, clusters=predicted_clusters
+                    )
                 )

                 with warnings.catch_warnings():
@@ -244,7 +248,7 @@ class LayoutModel(BasePageModel):

            yield page

-    def predict_on_page(self, *, page_image: Image.Image) -> list[Cluster]:
+    def predict_on_page_image(self, *, page_image: Image.Image) -> list[Cluster]:
         pred_items = self.layout_predictor.predict(page_image)

         clusters = []
@@ -263,7 +267,7 @@ class LayoutModel(BasePageModel):

         return clusters

-    def postprocess_on_page(
+    def postprocess_on_page_image(
         self, *, page: Page, clusters: list[Cluster]
     ) -> tuple[Page, list[Cluster], list[TextCell]]:
         processed_clusters, processed_cells = LayoutPostprocessor(
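
The renames split layout inference into an image-only prediction step and a page-level postprocessing step. A sketch of the resulting call sequence, assuming a constructed `LayoutModel` instance `model` and a `page` whose image has already been rendered as `page_image`:

    # predict_on_page_image needs only the rendered page image...
    clusters = model.predict_on_page_image(page_image=page_image)

    # ...while postprocessing needs the full Page to refine clusters and cells.
    page, processed_clusters, processed_cells = model.postprocess_on_page_image(
        page=page, clusters=clusters
    )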

@@ -5,10 +5,12 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, Optional

+from PIL import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -122,6 +124,43 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
             # Load generation config
             self.generation_config = GenerationConfig.from_pretrained(artifacts_path)

+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        # Define prompt structure
+        user_prompt = ""
+        if callable(self.vlm_options.prompt) and page is not None:
+            user_prompt = self.vlm_options.prompt(page.parsed_page)
+        elif isinstance(self.vlm_options.prompt, str):
+            user_prompt = self.vlm_options.prompt
+
+        prompt = self.formulate_prompt(user_prompt)
+        return prompt
+
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
+        output = ""
+
+        inputs = self.processor(
+            text=prompt, images=[page_image], return_tensors="pt"
+        ).to(self.device)
+
+        # Call model to generate:
+        generated_ids = self.vlm_model.generate(
+            **inputs,
+            max_new_tokens=self.max_new_tokens,
+            use_cache=self.use_cache,
+            temperature=self.temperature,
+            generation_config=self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        )
+
+        output = self.processor.batch_decode(
+            generated_ids[:, inputs["input_ids"].shape[1] :],
+            skip_special_tokens=False,
+        )[0]
+
+        return output, []
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -133,22 +172,29 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None

-                    hi_res_image = page.get_image(
+                    page_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
+                    assert page_image is not None

                     # Define prompt structure
+                    """
                     if callable(self.vlm_options.prompt):
                         user_prompt = self.vlm_options.prompt(page.parsed_page)
                     else:
                         user_prompt = self.vlm_options.prompt
                     prompt = self.formulate_prompt(user_prompt)

                     inputs = self.processor(
                         text=prompt, images=[hi_res_image], return_tensors="pt"
                     ).to(self.device)
+                    """
+                    prompt = self.get_user_prompt(page=page)

                     start_time = time.time()

+                    """
+                    inputs = self.processor(
+                        text=prompt, images=[page_image], return_tensors="pt"
+                    ).to(self.device)
+
                     # Call model to generate:
                     generated_ids = self.vlm_model.generate(
                         **inputs,
@@ -169,9 +215,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixi
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
+                    """
+
+                    generated_text = self.predict_on_page_image(
+                        page_image=page_image, prompt=prompt, output_tokens=False
+                    )

                     page.predictions.vlm_response = VlmPrediction(
-                        text=generated_texts,
-                        generation_time=generation_time,
+                        text=generated_text,
+                        generation_time=time.time() - start_time,
                     )

                 yield page
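
The inline prompt construction and generation code is factored into the two new helpers, and `__call__` now times the whole call itself. Note that `predict_on_page_image` is typed to return a `(text, tokens)` tuple, so callers would normally unpack it; a sketch assuming a model instance `vlm` and a PIL `page_image`:

    prompt = vlm.get_user_prompt(page=page)
    text, tokens = vlm.predict_on_page_image(
        page_image=page_image, prompt=prompt, output_tokens=False
    )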

@@ -115,7 +115,7 @@ class VlmPipeline(PaginatedPipeline):
                 TwoStageVlmOptions, self.pipeline_options.vlm_options
             )

-            layout_options = twostagevlm_options.lay_options
+            layout_options = twostagevlm_options.layout_options
             vlm_options = twostagevlm_options.vlm_options

             layout_model = LayoutModel(
@@ -125,24 +125,24 @@ class VlmPipeline(PaginatedPipeline):
             )

             if vlm_options.inference_framework == InferenceFramework.MLX:
-                vlm_model = HuggingFaceMlxModel(
+                vlm_model_mlx = HuggingFaceMlxModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     artifacts_path=artifacts_path,
                     accelerator_options=pipeline_options.accelerator_options,
                     vlm_options=vlm_options,
                 )
                 self.build_pipe = [
-                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
+                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model_mlx)
                 ]
             elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
-                vlm_model = HuggingFaceTransformersVlmModel(
+                vlm_model_hf = HuggingFaceTransformersVlmModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     artifacts_path=artifacts_path,
                     accelerator_options=pipeline_options.accelerator_options,
                     vlm_options=vlm_options,
                 )
                 self.build_pipe = [
-                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
+                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model_hf)
                 ]
             else:
                 raise ValueError(
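
Renaming the per-branch variables (`vlm_model_mlx`, `vlm_model_hf`) sidesteps MyPy's redefinition check, which by default infers a variable's type from its first assignment and rejects rebinding it to an unrelated class. A toy reproduction of that complaint (illustrative classes, not the docling ones):

    class MlxModel: ...
    class TransformersModel: ...

    def build(use_mlx: bool) -> object:
        if use_mlx:
            vlm_model = MlxModel()
        else:
            # MyPy (without --allow-redefinition) reports:
            #   Incompatible types in assignment (expression has type
            #   "TransformersModel", variable has type "MlxModel")
            vlm_model = TransformersModel()
        return vlm_model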