mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-25 19:44:34 +00:00
feat: working on a two stage VLM model
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
parent
4eceefa47c
commit
810446c8dc
@ -16,6 +16,7 @@ from docling.datamodel import asr_model_specs
|
||||
|
||||
# Import the following for backwards compatibility
|
||||
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.asr_model_specs import WHISPER_TINY as whisper_tiny
|
||||
from docling.datamodel.pipeline_options_asr_model import (
|
||||
InlineAsrOptions,
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from pydantic import AnyUrl, BaseModel
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorDevice
|
||||
from docling.datamodel.pipeline_options import LayoutOptions
|
||||
|
||||
|
||||
class BaseVlmOptions(BaseModel):
|
||||
@ -90,7 +91,7 @@ class ApiVlmOptions(BaseVlmOptions):
|
||||
|
||||
|
||||
class TwoStageVlmOptions(BaseVlmOptions):
    """Options bundle for the two-stage VLM model.

    Groups the options of the layout model and of the inline VLM that are
    used together by the two-stage model.
    """

    # Tag identifying this options type. The Literal type and the default
    # must be the same string, otherwise the default is not a valid value.
    kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options"

    # Options of the inline VLM.
    vlm_options: InlineVlmOptions
    # Options of the layout model.
    layout_options: LayoutOptions
|
||||
|
@ -3,6 +3,7 @@ from collections.abc import Iterable
|
||||
from typing import Generic, Optional, Protocol, Type
|
||||
|
||||
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
|
||||
from PIL import Image
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
||||
@ -19,12 +20,25 @@ class BaseModelWithOptions(Protocol):
|
||||
|
||||
|
||||
class BasePageModel(ABC):
    """Abstract base class for models that are applied to a batch of pages."""

    scale: float  # scale with which the page-image needs to be created (dpi = 72*scale)
    max_size: int  # max size of width/height of page-image

    @abstractmethod
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Process the pages of ``page_batch`` and return the resulting pages.

        Args:
            conv_res: The conversion result the pages belong to.
            page_batch: The pages to process.

        Returns:
            An iterable of pages.
        """
        pass
|
||||
|
||||
class BaseLayoutModel(BasePageModel):
    """Abstract page model that predicts layout clusters from a page image."""

    # NOTE(review): an import of `Cluster` is not visible in this part of the
    # module - confirm it is in scope, since the annotation below is evaluated
    # when the class body runs.
    @abstractmethod
    def predict_on_page_image(self, *, page_image: Image.Image) -> list[Cluster]:
        """Return the clusters predicted for ``page_image``.

        Args:
            page_image: The rendered page image (keyword-only).
        """
        pass
|
||||
|
||||
class BaseVlmModel(BasePageModel):
    """Abstract page model that produces text from a page image and a prompt."""

    # NOTE(review): the two-stage model in this change calls
    # `vlm_model.predict_on_image(...)` and unpacks two return values, while
    # this interface declares `predict_on_page_image(...) -> str` - confirm
    # which name and return shape are intended.
    @abstractmethod
    def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
        """Return the text generated for ``page_image`` given ``prompt``.

        Args:
            page_image: The rendered page image (keyword-only).
            prompt: The prompt passed to the model (keyword-only).
        """
        pass
|
||||
|
||||
|
||||
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
||||
|
||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import DocItemLabel
|
||||
from docling_core.types.doc.page import TextCell
|
||||
from PIL import Image
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||
@ -176,11 +177,11 @@ class LayoutModel(BasePageModel):
|
||||
)
|
||||
clusters.append(cluster)
|
||||
"""
|
||||
clusters = self.predict_on_page(page_image)
|
||||
|
||||
predicted_clusters = self.predict_on_page(page_image=page_image)
|
||||
|
||||
if settings.debug.visualize_raw_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, clusters, mode_prefix="raw"
|
||||
conv_res, page, predicted_clusters, mode_prefix="raw"
|
||||
)
|
||||
|
||||
# Apply postprocessing
|
||||
@ -212,8 +213,28 @@ class LayoutModel(BasePageModel):
|
||||
clusters=processed_clusters
|
||||
)
|
||||
"""
|
||||
page, processed_clusters, processed_cells = self.postprocess_on_page(page, cluster)
|
||||
|
||||
page, processed_clusters, processed_cells = (
|
||||
self.postprocess_on_page(page=page, clusters=predicted_clusters)
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"Mean of empty slice|invalid value encountered in scalar divide",
|
||||
RuntimeWarning,
|
||||
"numpy",
|
||||
)
|
||||
|
||||
conv_res.confidence.pages[page.page_no].layout_score = float(
|
||||
np.mean([c.confidence for c in processed_clusters])
|
||||
)
|
||||
|
||||
conv_res.confidence.pages[page.page_no].ocr_score = float(
|
||||
np.mean(
|
||||
[c.confidence for c in processed_cells if c.from_ocr]
|
||||
)
|
||||
)
|
||||
|
||||
if settings.debug.visualize_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||
@ -221,17 +242,13 @@ class LayoutModel(BasePageModel):
|
||||
|
||||
yield page
|
||||
|
||||
def predict_on_page(self, page_image: Image) -> list[Cluster]:
|
||||
|
||||
def predict_on_page(self, *, page_image: Image.Image) -> list[Cluster]:
|
||||
pred_items = self.layout_predictor.predict(page_image)
|
||||
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(pred_items):
|
||||
label = DocItemLabel(
|
||||
pred_item["label"]
|
||||
.lower()
|
||||
.replace(" ", "_")
|
||||
.replace("-", "_")
|
||||
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
|
||||
) # Temporary, until docling-ibm-model uses docling-core types
|
||||
cluster = Cluster(
|
||||
id=ix,
|
||||
@ -241,36 +258,17 @@ class LayoutModel(BasePageModel):
|
||||
cells=[],
|
||||
)
|
||||
clusters.append(cluster)
|
||||
|
||||
|
||||
return clusters
|
||||
|
||||
def postprocess_on_page(
    self, *, page: Page, clusters: list[Cluster]
) -> tuple[Page, list[Cluster], list[TextCell]]:
    """Post-process predicted layout clusters and store them on the page.

    Args:
        page: The page the clusters were predicted for (keyword-only).
        clusters: The raw predicted clusters (keyword-only).

    Returns:
        A tuple ``(page, processed_clusters, processed_cells)``.
    """
    processed_clusters, processed_cells = LayoutPostprocessor(
        page, clusters, self.options
    ).postprocess()
    # Note: LayoutPostprocessor updates page.cells and page.parsed_page internally

    # Confidence scoring is not done here: it needs the conversion result,
    # which this method does not receive.
    page.predictions.layout = LayoutPrediction(clusters=processed_clusters)

    return page, processed_clusters, processed_cells
|
||||
|
@ -15,7 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
||||
TransformersModelType,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.base_model import BasePageModel, BaseVlmModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
@ -25,7 +25,7 @@ from docling.utils.profiling import TimeRecorder
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@ -37,6 +37,9 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
|
||||
|
||||
self.vlm_options = vlm_options
|
||||
|
||||
self.scale = self.vlm_options.scale
|
||||
self.max_size = self.vlm_options.max_size
|
||||
|
||||
if self.enabled:
|
||||
import torch
|
||||
from transformers import (
|
||||
|
@ -10,7 +10,7 @@ from docling.datamodel.accelerator_options import (
|
||||
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.base_model import BasePageModel, BaseVlmModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
@ -19,7 +19,7 @@ from docling.utils.profiling import TimeRecorder
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@ -28,10 +28,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
vlm_options: InlineVlmOptions,
|
||||
):
|
||||
self.enabled = enabled
|
||||
|
||||
self.vlm_options = vlm_options
|
||||
|
||||
self.max_tokens = vlm_options.max_new_tokens
|
||||
self.temperature = vlm_options.temperature
|
||||
self.scale = self.vlm_options.scale
|
||||
self.max_size = self.vlm_options.max_size
|
||||
|
||||
if self.enabled:
|
||||
try:
|
||||
|
@ -15,7 +15,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
||||
TransformersModelType,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.base_model import BasePageModel, BaseVlmModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
@ -29,11 +30,11 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
layout_model: LayoutModelModel,
|
||||
vlm_model: BasePageModel,
|
||||
layout_model: LayoutModel,
|
||||
vlm_model: BaseVlmModel,
|
||||
):
|
||||
self.layout_model = layout_options
|
||||
self.vlm_model = vlm_options
|
||||
self.layout_model = layout_model
|
||||
self.vlm_model = vlm_model
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
@ -47,23 +48,27 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
assert page.size is not None
|
||||
|
||||
page_image = page.get_image(
|
||||
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
||||
scale=self.vlm_model.scale, max_size=self.vlm_model.max_size
|
||||
)
|
||||
|
||||
pred_clusters = self.layout_model.predict_on_page(page_image=page_image)
|
||||
page, processed_clusters, processed_cells = (
|
||||
self.layout_model.postprocess_on_page(
|
||||
page=page, clusters=pred_clusters
|
||||
)
|
||||
)
|
||||
|
||||
pred_clusters = self.layout_model.predict_on_page(page_image)
|
||||
page, processed_clusters, processed_cells = self.layout_model.postprocess_on_page(page=page,
|
||||
page_image=page_image)
|
||||
|
||||
# Define prompt structure
|
||||
if callable(self.vlm_options.prompt):
|
||||
user_prompt = self.vlm_options.prompt(page.parsed_page)
|
||||
else:
|
||||
user_prompt = self.vlm_options.prompt
|
||||
|
||||
|
||||
prompt = self.formulate_prompt(user_prompt, processed_clusters)
|
||||
|
||||
generated_text, generation_time = self.vlm_model.predict_on_image(page_image=page_image,
|
||||
prompt=prompt)
|
||||
generated_text, generation_time = self.vlm_model.predict_on_image(
|
||||
page_image=page_image, prompt=prompt
|
||||
)
|
||||
|
||||
page.predictions.vlm_response = VlmPrediction(
|
||||
text=generated_text,
|
||||
@ -72,7 +77,7 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
|
||||
yield page
|
||||
|
||||
def formulate_prompt(self, user_prompt: str, clusters:list[Cluster]) -> str:
|
||||
def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
|
||||
"""Formulate a prompt for the VLM."""
|
||||
|
||||
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
@ -26,9 +26,7 @@ from docling.backend.md_backend import MarkdownDocumentBackend
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import InputFormat, Page
|
||||
from docling.datamodel.document import ConversionResult, InputDocument
|
||||
from docling.datamodel.pipeline_options import (
|
||||
VlmPipelineOptions,
|
||||
)
|
||||
from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
|
||||
from docling.datamodel.pipeline_options_vlm_model import (
|
||||
ApiVlmOptions,
|
||||
InferenceFramework,
|
||||
@ -37,15 +35,17 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.api_vlm_model import ApiVlmModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.vlm_models_inline.hf_transformers_model import (
|
||||
HuggingFaceTransformersVlmModel,
|
||||
)
|
||||
from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
|
||||
from docling.models.vlm_models_inline.two_stage_vlm_model import (
|
||||
TwoStageVlmModel,
|
||||
)
|
||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
from docling.models.layout_model import LayoutModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -110,7 +110,9 @@ class VlmPipeline(PaginatedPipeline):
|
||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||
)
|
||||
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
|
||||
twostagevlm_options = cast(TwoStageVlmOptions, self.pipeline_options.vlm_options)
|
||||
twostagevlm_options = cast(
|
||||
TwoStageVlmOptions, self.pipeline_options.vlm_options
|
||||
)
|
||||
|
||||
layout_options = twostagevlm_options.lay_options
|
||||
vlm_options = twostagevlm_options.vlm_options
|
||||
@ -120,7 +122,7 @@ class VlmPipeline(PaginatedPipeline):
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
options=layout_options,
|
||||
)
|
||||
|
||||
|
||||
if vlm_options.inference_framework == InferenceFramework.MLX:
|
||||
vlm_model = HuggingFaceMlxModel(
|
||||
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||
@ -145,7 +147,7 @@ class VlmPipeline(PaginatedPipeline):
|
||||
raise ValueError(
|
||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||
)
|
||||
|
||||
|
||||
self.enrichment_pipe = [
|
||||
# Other models working on `NodeItem` elements in the DoclingDocument
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user