mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00
feat: working on a two stage VLM model
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
parent
4eceefa47c
commit
810446c8dc
@ -16,6 +16,7 @@ from docling.datamodel import asr_model_specs
|
|||||||
|
|
||||||
# Import the following for backwards compatibility
|
# Import the following for backwards compatibility
|
||||||
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
||||||
|
from docling.datamodel.asr_model_specs import WHISPER_TINY as whisper_tiny
|
||||||
from docling.datamodel.pipeline_options_asr_model import (
|
from docling.datamodel.pipeline_options_asr_model import (
|
||||||
InlineAsrOptions,
|
InlineAsrOptions,
|
||||||
)
|
)
|
||||||
|
@ -6,6 +6,7 @@ from pydantic import AnyUrl, BaseModel
|
|||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from docling.datamodel.accelerator_options import AcceleratorDevice
|
from docling.datamodel.accelerator_options import AcceleratorDevice
|
||||||
|
from docling.datamodel.pipeline_options import LayoutOptions
|
||||||
|
|
||||||
|
|
||||||
class BaseVlmOptions(BaseModel):
|
class BaseVlmOptions(BaseModel):
|
||||||
@ -90,7 +91,7 @@ class ApiVlmOptions(BaseVlmOptions):
|
|||||||
|
|
||||||
|
|
||||||
class TwoStageVlmOptions(BaseVlmOptions):
|
class TwoStageVlmOptions(BaseVlmOptions):
|
||||||
kind: Literal["inline_model_options"] = "inline_two_stage_model_options"
|
kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options"
|
||||||
|
|
||||||
vlm_options: UnionInlineVlmOptions
|
vlm_options: InlineVlmOptions
|
||||||
layout_options: LayoutOptions
|
layout_options: LayoutOptions
|
||||||
|
@ -3,6 +3,7 @@ from collections.abc import Iterable
|
|||||||
from typing import Generic, Optional, Protocol, Type
|
from typing import Generic, Optional, Protocol, Type
|
||||||
|
|
||||||
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
|
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
|
||||||
|
from PIL import Image
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
||||||
@ -19,12 +20,25 @@ class BaseModelWithOptions(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class BasePageModel(ABC):
|
class BasePageModel(ABC):
|
||||||
|
scale: float # scale with which the page-image needs to be created (dpi = 72*scale)
|
||||||
|
max_size: int # max size of width/height of page-image
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(
|
def __call__(
|
||||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||||
) -> Iterable[Page]:
|
) -> Iterable[Page]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class BaseLayoutModel(BasePageModel):
|
||||||
|
@abstractmethod
|
||||||
|
def predict_on_page_image(self, *, page_image: Image.Image) -> list(Cluster):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BaseVlmModel(BasePageModel):
|
||||||
|
@abstractmethod
|
||||||
|
def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from docling_core.types.doc import DocItemLabel
|
from docling_core.types.doc import DocItemLabel
|
||||||
|
from docling_core.types.doc.page import TextCell
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||||
@ -176,11 +177,11 @@ class LayoutModel(BasePageModel):
|
|||||||
)
|
)
|
||||||
clusters.append(cluster)
|
clusters.append(cluster)
|
||||||
"""
|
"""
|
||||||
clusters = self.predict_on_page(page_image)
|
predicted_clusters = self.predict_on_page(page_image=page_image)
|
||||||
|
|
||||||
if settings.debug.visualize_raw_layout:
|
if settings.debug.visualize_raw_layout:
|
||||||
self.draw_clusters_and_cells_side_by_side(
|
self.draw_clusters_and_cells_side_by_side(
|
||||||
conv_res, page, clusters, mode_prefix="raw"
|
conv_res, page, predicted_clusters, mode_prefix="raw"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply postprocessing
|
# Apply postprocessing
|
||||||
@ -212,45 +213,10 @@ class LayoutModel(BasePageModel):
|
|||||||
clusters=processed_clusters
|
clusters=processed_clusters
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
page, processed_clusters, processed_cells = self.postprocess_on_page(page, cluster)
|
page, processed_clusters, processed_cells = (
|
||||||
|
self.postprocess_on_page(page=page, clusters=predicted_clusters)
|
||||||
if settings.debug.visualize_layout:
|
|
||||||
self.draw_clusters_and_cells_side_by_side(
|
|
||||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield page
|
|
||||||
|
|
||||||
def predict_on_page(self, page_image: Image) -> list[Cluster]:
|
|
||||||
|
|
||||||
pred_items = self.layout_predictor.predict(page_image)
|
|
||||||
|
|
||||||
clusters = []
|
|
||||||
for ix, pred_item in enumerate(pred_items):
|
|
||||||
label = DocItemLabel(
|
|
||||||
pred_item["label"]
|
|
||||||
.lower()
|
|
||||||
.replace(" ", "_")
|
|
||||||
.replace("-", "_")
|
|
||||||
) # Temporary, until docling-ibm-model uses docling-core types
|
|
||||||
cluster = Cluster(
|
|
||||||
id=ix,
|
|
||||||
label=label,
|
|
||||||
confidence=pred_item["confidence"],
|
|
||||||
bbox=BoundingBox.model_validate(pred_item),
|
|
||||||
cells=[],
|
|
||||||
)
|
|
||||||
clusters.append(cluster)
|
|
||||||
|
|
||||||
return clusters
|
|
||||||
|
|
||||||
def postprocess_on_page(self, page: Page, cluster: list(Cluster)):
|
|
||||||
|
|
||||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
|
||||||
page, clusters, self.options
|
|
||||||
).postprocess()
|
|
||||||
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore",
|
"ignore",
|
||||||
@ -269,8 +235,40 @@ class LayoutModel(BasePageModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
page.predictions.layout = LayoutPrediction(
|
if settings.debug.visualize_layout:
|
||||||
clusters=processed_clusters
|
self.draw_clusters_and_cells_side_by_side(
|
||||||
|
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield page
|
||||||
|
|
||||||
|
def predict_on_page(self, *, page_image: Image.Image) -> list[Cluster]:
|
||||||
|
pred_items = self.layout_predictor.predict(page_image)
|
||||||
|
|
||||||
|
clusters = []
|
||||||
|
for ix, pred_item in enumerate(pred_items):
|
||||||
|
label = DocItemLabel(
|
||||||
|
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
|
||||||
|
) # Temporary, until docling-ibm-model uses docling-core types
|
||||||
|
cluster = Cluster(
|
||||||
|
id=ix,
|
||||||
|
label=label,
|
||||||
|
confidence=pred_item["confidence"],
|
||||||
|
bbox=BoundingBox.model_validate(pred_item),
|
||||||
|
cells=[],
|
||||||
|
)
|
||||||
|
clusters.append(cluster)
|
||||||
|
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def postprocess_on_page(
|
||||||
|
self, *, page: Page, clusters: list[Cluster]
|
||||||
|
) -> tuple[Page, list[Cluster], list[TextCell]]:
|
||||||
|
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||||
|
page, clusters, self.options
|
||||||
|
).postprocess()
|
||||||
|
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
|
||||||
|
|
||||||
|
page.predictions.layout = LayoutPrediction(clusters=processed_clusters)
|
||||||
|
|
||||||
return page, processed_clusters, processed_cells
|
return page, processed_clusters, processed_cells
|
||||||
|
@ -15,7 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
|||||||
TransformersModelType,
|
TransformersModelType,
|
||||||
TransformersPromptStyle,
|
TransformersPromptStyle,
|
||||||
)
|
)
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel, BaseVlmModel
|
||||||
from docling.models.utils.hf_model_download import (
|
from docling.models.utils.hf_model_download import (
|
||||||
HuggingFaceModelDownloadMixin,
|
HuggingFaceModelDownloadMixin,
|
||||||
)
|
)
|
||||||
@ -25,7 +25,7 @@ from docling.utils.profiling import TimeRecorder
|
|||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
@ -37,6 +37,9 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
|
|||||||
|
|
||||||
self.vlm_options = vlm_options
|
self.vlm_options = vlm_options
|
||||||
|
|
||||||
|
self.scale = self.vlm_options.scale
|
||||||
|
self.max_size = self.vlm_options.max_size
|
||||||
|
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
@ -10,7 +10,7 @@ from docling.datamodel.accelerator_options import (
|
|||||||
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
|
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
|
||||||
from docling.datamodel.document import ConversionResult
|
from docling.datamodel.document import ConversionResult
|
||||||
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
|
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel, BaseVlmModel
|
||||||
from docling.models.utils.hf_model_download import (
|
from docling.models.utils.hf_model_download import (
|
||||||
HuggingFaceModelDownloadMixin,
|
HuggingFaceModelDownloadMixin,
|
||||||
)
|
)
|
||||||
@ -19,7 +19,7 @@ from docling.utils.profiling import TimeRecorder
|
|||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
@ -28,10 +28,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
|||||||
vlm_options: InlineVlmOptions,
|
vlm_options: InlineVlmOptions,
|
||||||
):
|
):
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
|
|
||||||
self.vlm_options = vlm_options
|
self.vlm_options = vlm_options
|
||||||
|
|
||||||
self.max_tokens = vlm_options.max_new_tokens
|
self.max_tokens = vlm_options.max_new_tokens
|
||||||
self.temperature = vlm_options.temperature
|
self.temperature = vlm_options.temperature
|
||||||
|
self.scale = self.vlm_options.scale
|
||||||
|
self.max_size = self.vlm_options.max_size
|
||||||
|
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
try:
|
try:
|
||||||
|
@ -15,7 +15,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
|||||||
TransformersModelType,
|
TransformersModelType,
|
||||||
TransformersPromptStyle,
|
TransformersPromptStyle,
|
||||||
)
|
)
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel, BaseVlmModel
|
||||||
|
from docling.models.layout_model import LayoutModel
|
||||||
from docling.models.utils.hf_model_download import (
|
from docling.models.utils.hf_model_download import (
|
||||||
HuggingFaceModelDownloadMixin,
|
HuggingFaceModelDownloadMixin,
|
||||||
)
|
)
|
||||||
@ -29,11 +30,11 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
layout_model: LayoutModelModel,
|
layout_model: LayoutModel,
|
||||||
vlm_model: BasePageModel,
|
vlm_model: BaseVlmModel,
|
||||||
):
|
):
|
||||||
self.layout_model = layout_options
|
self.layout_model = layout_model
|
||||||
self.vlm_model = vlm_options
|
self.vlm_model = vlm_model
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||||
@ -47,12 +48,15 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
|||||||
assert page.size is not None
|
assert page.size is not None
|
||||||
|
|
||||||
page_image = page.get_image(
|
page_image = page.get_image(
|
||||||
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
scale=self.vlm_model.scale, max_size=self.vlm_model.max_size
|
||||||
)
|
)
|
||||||
|
|
||||||
pred_clusters = self.layout_model.predict_on_page(page_image)
|
pred_clusters = self.layout_model.predict_on_page(page_image=page_image)
|
||||||
page, processed_clusters, processed_cells = self.layout_model.postprocess_on_page(page=page,
|
page, processed_clusters, processed_cells = (
|
||||||
page_image=page_image)
|
self.layout_model.postprocess_on_page(
|
||||||
|
page=page, clusters=pred_clusters
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Define prompt structure
|
# Define prompt structure
|
||||||
if callable(self.vlm_options.prompt):
|
if callable(self.vlm_options.prompt):
|
||||||
@ -62,8 +66,9 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
|||||||
|
|
||||||
prompt = self.formulate_prompt(user_prompt, processed_clusters)
|
prompt = self.formulate_prompt(user_prompt, processed_clusters)
|
||||||
|
|
||||||
generated_text, generation_time = self.vlm_model.predict_on_image(page_image=page_image,
|
generated_text, generation_time = self.vlm_model.predict_on_image(
|
||||||
prompt=prompt)
|
page_image=page_image, prompt=prompt
|
||||||
|
)
|
||||||
|
|
||||||
page.predictions.vlm_response = VlmPrediction(
|
page.predictions.vlm_response = VlmPrediction(
|
||||||
text=generated_text,
|
text=generated_text,
|
||||||
@ -72,7 +77,7 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
|||||||
|
|
||||||
yield page
|
yield page
|
||||||
|
|
||||||
def formulate_prompt(self, user_prompt: str, clusters:list[Cluster]) -> str:
|
def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
|
||||||
"""Formulate a prompt for the VLM."""
|
"""Formulate a prompt for the VLM."""
|
||||||
|
|
||||||
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
@ -26,9 +26,7 @@ from docling.backend.md_backend import MarkdownDocumentBackend
|
|||||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||||
from docling.datamodel.base_models import InputFormat, Page
|
from docling.datamodel.base_models import InputFormat, Page
|
||||||
from docling.datamodel.document import ConversionResult, InputDocument
|
from docling.datamodel.document import ConversionResult, InputDocument
|
||||||
from docling.datamodel.pipeline_options import (
|
from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
|
||||||
VlmPipelineOptions,
|
|
||||||
)
|
|
||||||
from docling.datamodel.pipeline_options_vlm_model import (
|
from docling.datamodel.pipeline_options_vlm_model import (
|
||||||
ApiVlmOptions,
|
ApiVlmOptions,
|
||||||
InferenceFramework,
|
InferenceFramework,
|
||||||
@ -37,15 +35,17 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
|||||||
)
|
)
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.api_vlm_model import ApiVlmModel
|
from docling.models.api_vlm_model import ApiVlmModel
|
||||||
|
from docling.models.layout_model import LayoutModel
|
||||||
from docling.models.vlm_models_inline.hf_transformers_model import (
|
from docling.models.vlm_models_inline.hf_transformers_model import (
|
||||||
HuggingFaceTransformersVlmModel,
|
HuggingFaceTransformersVlmModel,
|
||||||
)
|
)
|
||||||
from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
|
from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
|
||||||
|
from docling.models.vlm_models_inline.two_stage_vlm_model import (
|
||||||
|
TwoStageVlmModel,
|
||||||
|
)
|
||||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||||
|
|
||||||
from docling.models.layout_model import LayoutModel
|
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -110,7 +110,9 @@ class VlmPipeline(PaginatedPipeline):
|
|||||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||||
)
|
)
|
||||||
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
|
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
|
||||||
twostagevlm_options = cast(TwoStageVlmOptions, self.pipeline_options.vlm_options)
|
twostagevlm_options = cast(
|
||||||
|
TwoStageVlmOptions, self.pipeline_options.vlm_options
|
||||||
|
)
|
||||||
|
|
||||||
layout_options = twostagevlm_options.lay_options
|
layout_options = twostagevlm_options.lay_options
|
||||||
vlm_options = twostagevlm_options.vlm_options
|
vlm_options = twostagevlm_options.vlm_options
|
||||||
|
Loading…
Reference in New Issue
Block a user