Mirror of https://github.com/DS4SD/docling.git, synced 2025-07-27 04:24:45 +00:00

working two-stage vlm approach from the cli

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

This commit is contained in:
parent fb74d0c5b3
commit b2d5c783ae
@@ -63,6 +63,7 @@ from docling.datamodel.vlm_model_specs import (
     GRANITE_VISION_TRANSFORMERS,
     SMOLDOCLING_MLX,
     SMOLDOCLING_TRANSFORMERS,
+    VLM2STAGE,
     VlmModelType,
 )
 from docling.document_converter import (
@@ -627,6 +628,12 @@ def convert(  # noqa: C901
                         "To run SmolDocling faster, please install mlx-vlm:\n"
                         "pip install mlx-vlm"
                     )
+        elif vlm_model == VlmModelType.VLM2STAGE:
+            pipeline_options.vlm_options = VLM2STAGE
+        else:
+            raise ValueError(
+                f"{vlm_model} is not of type GRANITE_VISION, GRANITE_VISION_OLLAMA, SMOLDOCLING_TRANSFORMERS or VLM2STAGE"
+            )

     pdf_format_option = PdfFormatOption(
         pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
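Note: the CLI wiring above maps the new enum value to the VLM2STAGE spec and hands it to the VLM pipeline. A minimal sketch of the equivalent programmatic setup, assuming the import paths docling uses elsewhere for the VLM pipeline (the input path is a placeholder):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import VlmPipelineOptions
    from docling.datamodel.vlm_model_specs import VLM2STAGE
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    # Select the two-stage model spec introduced by this commit
    pipeline_options = VlmPipelineOptions(vlm_options=VLM2STAGE)

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
            )
        }
    )
    result = converter.convert("document.pdf")  # placeholder input path
    print(result.document.export_to_markdown())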
@@ -153,4 +153,4 @@ class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
     GRANITE_VISION = "granite_vision"
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
-    VLM2STAGE = "docling2stage"
+    VLM2STAGE = "vlm2stage"
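With the enum value corrected from "docling2stage" to "vlm2stage", the string typed on the command line now matches the mode's name. A usage sketch, assuming docling's existing --pipeline and --vlm-model options (the input file is a placeholder):

    docling --pipeline vlm --vlm-model vlm2stage document.pdf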
@@ -40,7 +40,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin
         self.vlm_options = vlm_options

         self.scale = self.vlm_options.scale
-        # self.max_size = self.vlm_options.max_size
+        self.max_size = 512
+        if isinstance(self.vlm_options.max_size, int):
+            self.max_size = self.vlm_options.max_size

         if self.enabled:
             import torch
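The previously commented-out assignment becomes a default-with-override: the rendered page image is capped at 512 px (presumably on its longest side, per the page.get_image call further below) unless the model spec supplies an integer max_size. A compact equivalent of the added lines, as a sketch with the same semantics:

    # Conditional-expression form of the default-with-override above
    self.max_size = (
        self.vlm_options.max_size
        if isinstance(self.vlm_options.max_size, int)
        else 512
    )

The same logic is duplicated for the MLX model in the next hunk; the two-stage model later reads both scale and max_size back through self.vlm_model.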
@@ -35,7 +35,10 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.max_tokens = vlm_options.max_new_tokens
         self.temperature = vlm_options.temperature
         self.scale = self.vlm_options.scale
-        # self.max_size = self.vlm_options.max_size
+
+        self.max_size = 512
+        if isinstance(self.vlm_options.max_size, int):
+            self.max_size = self.vlm_options.max_size

         if self.enabled:
             try:
@@ -50,12 +50,12 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 page_image = page.get_image(
                     scale=self.vlm_model.scale, max_size=self.vlm_model.max_size
                 )

                 assert page_image is not None

                 pred_clusters = self.layout_model.predict_on_page_image(
                     page_image=page_image
                 )

                 page, processed_clusters, processed_cells = (
                     self.layout_model.postprocess_on_page_image(
                         page=page, clusters=pred_clusters
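For orientation, stage one of the two-stage flow in this hunk, condensed into a sketch (the names come from the diff above; this is not the actual method body):

    # Stage 1: rasterize the page and run the layout model on it
    page_image = page.get_image(scale=vlm_model.scale, max_size=vlm_model.max_size)
    clusters = layout_model.predict_on_page_image(page_image=page_image)
    page, clusters, cells = layout_model.postprocess_on_page_image(
        page=page, clusters=clusters
    )
    # Stage 2 (next hunk): fold the layout clusters into the prompt, run the VLM
    prompt = formulate_prompt(user_prompt=user_prompt, clusters=clusters)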
@@ -68,14 +68,17 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 )

                 start_time = time.time()
-                generated_text = self.vlm_model.predict_on_page_image(
-                    page_image=page_image, prompt=prompt
+                generated_text, generated_tokens = (
+                    self.vlm_model.predict_on_page_image(
+                        page_image=page_image, prompt=prompt
+                    )
                 )

                 page.predictions.vlm_response = VlmPrediction(
-                    text=generated_text, generation_time=time.time() - start_time
+                    text=generated_text,
+                    generation_time=time.time() - start_time,
+                    generated_tokens=generated_tokens,
                 )

                 yield page

     def formulate_prompt(self, *, user_prompt: str, clusters: list[Cluster]) -> str:
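This call site now expects predict_on_page_image to return a (text, tokens) pair so token-level output can be stored on VlmPrediction alongside the generation time. A sketch of the contract the implementing model classes would need to satisfy (the helper and field names here are hypothetical; the real bodies live in the HF transformers and MLX models changed above):

    # Hypothetical shape of the two-value return contract
    def predict_on_page_image(self, *, page_image, prompt: str) -> tuple[str, list]:
        output = self._run_generation(page_image, prompt)  # hypothetical helper
        return output.text, output.tokens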