diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py index 99200c56..07f01f04 100644 --- a/docling/datamodel/pipeline_options_vlm_model.py +++ b/docling/datamodel/pipeline_options_vlm_model.py @@ -27,6 +27,7 @@ class ResponseFormat(str, Enum): class InferenceFramework(str, Enum): MLX = "mlx" TRANSFORMERS = "transformers" + VLLM = "vllm" class TransformersModelType(str, Enum): diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py index f2ea206d..0abb1642 100644 --- a/docling/datamodel/vlm_model_specs.py +++ b/docling/datamodel/vlm_model_specs.py @@ -44,6 +44,20 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions( temperature=0.0, ) +SMOLDOCLING_VLLM = InlineVlmOptions( + repo_id="ds4sd/SmolDocling-256M-preview", + prompt="Convert this page to docling.", + response_format=ResponseFormat.DOCTAGS, + inference_framework=InferenceFramework.VLLM, + transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ, + supported_devices=[ + AcceleratorDevice.CPU, + AcceleratorDevice.CUDA, + ], + scale=2.0, + temperature=0.0, +) + # GraniteVision GRANITE_VISION_TRANSFORMERS = InlineVlmOptions( repo_id="ibm-granite/granite-vision-3.2-2b", @@ -60,6 +74,20 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions( temperature=0.0, ) +GRANITE_VISION_VLLM = InlineVlmOptions( + repo_id="ibm-granite/granite-vision-3.2-2b", + prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!", + response_format=ResponseFormat.MARKDOWN, + inference_framework=InferenceFramework.VLLM, + transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ, + supported_devices=[ + AcceleratorDevice.CPU, + AcceleratorDevice.CUDA, + ], + scale=2.0, + temperature=0.0, +) + GRANITE_VISION_OLLAMA = ApiVlmOptions( url=AnyUrl("http://localhost:11434/v1/chat/completions"), params={"model": "granite3.2-vision:2b"}, @@ -158,5 +186,7 @@ DOLPHIN_TRANSFORMERS = InlineVlmOptions( class VlmModelType(str, Enum): SMOLDOCLING = "smoldocling" + SMOLDOCLING_VLLM = "smoldocling_vllm" GRANITE_VISION = "granite_vision" + GRANITE_VISION_VLLM = "granite_vision_vllm" GRANITE_VISION_OLLAMA = "granite_vision_ollama" diff --git a/docling/models/base_model.py b/docling/models/base_model.py index fbfd5880..9ce1d753 100644 --- a/docling/models/base_model.py +++ b/docling/models/base_model.py @@ -37,9 +37,21 @@ class BaseVlmModel(ABC): @abstractmethod def process_images( - self, image_batch: Iterable[Union[Image, np.ndarray]] + self, + image_batch: Iterable[Union[Image, np.ndarray]], + prompt: Union[str, list[str]], ) -> Iterable[VlmPrediction]: - """Process raw images without page metadata.""" + """Process raw images without page metadata. + + Args: + image_batch: Iterable of PIL Images or numpy arrays + prompt: Either: + - str: Single prompt used for all images + - list[str]: List of prompts (one per image, must match image count) + + Raises: + ValueError: If prompt list length doesn't match image count. + """ class BaseVlmPageModel(BasePageModel, BaseVlmModel): @@ -55,23 +67,6 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel): ) -> Iterable[Page]: """Extract images from pages, process them, and attach results back.""" - @abstractmethod - def process_images( - self, - image_batch: Iterable[Union[Image, np.ndarray]], - prompt: Optional[str] = None, - ) -> Iterable[VlmPrediction]: - """Process raw images without page metadata. - - Args: - image_batch: Iterable of PIL Images or numpy arrays - prompt: Optional prompt string. 
If None, uses vlm_options.prompt if it's a string. - If vlm_options.prompt is callable and no prompt is provided, raises ValueError. - - Raises: - ValueError: If vlm_options.prompt is callable and no prompt parameter is provided. - """ - EnrichElementT = TypeVar("EnrichElementT", default=NodeItem) diff --git a/docling/models/vlm_models_inline/__init__.py b/docling/models/vlm_models_inline/__init__.py index e69de29b..8b137891 100644 --- a/docling/models/vlm_models_inline/__init__.py +++ b/docling/models/vlm_models_inline/__init__.py @@ -0,0 +1 @@ + diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py index 96822679..e81eafc3 100644 --- a/docling/models/vlm_models_inline/hf_transformers_model.py +++ b/docling/models/vlm_models_inline/hf_transformers_model.py @@ -125,56 +125,60 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: - for page in page_batch: + page_list = list(page_batch) + if not page_list: + return + + valid_pages = [] + invalid_pages = [] + + for page in page_list: assert page._backend is not None if not page._backend.is_valid(): - yield page + invalid_pages.append(page) else: - with TimeRecorder(conv_res, "vlm"): - assert page.size is not None + valid_pages.append(page) + # Process valid pages in batch + if valid_pages: + with TimeRecorder(conv_res, "vlm"): + # Prepare images and prompts for batch processing + images = [] + user_prompts = [] + pages_with_images = [] + + for page in valid_pages: + assert page.size is not None hi_res_image = page.get_image( scale=self.vlm_options.scale, max_size=self.vlm_options.max_size ) - # Define prompt structure - if callable(self.vlm_options.prompt): - user_prompt = self.vlm_options.prompt(page.parsed_page) - else: - user_prompt = self.vlm_options.prompt - prompt = self.formulate_prompt(user_prompt) + # Only process pages with valid images + if hi_res_image is not None: + images.append(hi_res_image) - inputs = self.processor( - text=prompt, images=[hi_res_image], return_tensors="pt" - ).to(self.device) + # Define prompt structure + if callable(self.vlm_options.prompt): + user_prompt = self.vlm_options.prompt(page.parsed_page) + else: + user_prompt = self.vlm_options.prompt - start_time = time.time() - # Call model to generate: - generated_ids = self.vlm_model.generate( - **inputs, - max_new_tokens=self.max_new_tokens, - use_cache=self.use_cache, - temperature=self.temperature, - generation_config=self.generation_config, - **self.vlm_options.extra_generation_config, - ) + user_prompts.append(user_prompt) + pages_with_images.append(page) - generation_time = time.time() - start_time - generated_texts = self.processor.batch_decode( - generated_ids[:, inputs["input_ids"].shape[1] :], - skip_special_tokens=True, - )[0] + # Use process_images for the actual inference + if images: # Only if we have valid images + predictions = list(self.process_images(images, user_prompts)) - num_tokens = len(generated_ids[0]) - _log.debug( - f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds." 
- ) - page.predictions.vlm_response = VlmPrediction( - text=generated_texts, - generation_time=generation_time, - ) + # Attach results to pages + for page, prediction in zip(pages_with_images, predictions): + page.predictions.vlm_response = prediction - yield page + # Yield all pages (valid and invalid) + for page in invalid_pages: + yield page + for page in valid_pages: + yield page def formulate_prompt(self, user_prompt: str) -> str: """Formulate a prompt for the VLM.""" @@ -221,9 +225,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload def process_images( self, image_batch: Iterable[Union[Image, np.ndarray]], - prompt: Optional[str] = None, + prompt: Union[str, list[str]], ) -> Iterable[VlmPrediction]: - """Process raw images without page metadata in a single batched inference call.""" + """Process raw images without page metadata in a single batched inference call. + + Args: + image_batch: Iterable of PIL Images or numpy arrays + prompt: Either: + - str: Single prompt used for all images + - list[str]: List of prompts (one per image, must match image count) + + Raises: + ValueError: If prompt list length doesn't match image count. + """ pil_images: list[Image] = [] for img in image_batch: @@ -251,19 +265,24 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload if len(pil_images) == 0: return - # Handle prompt with priority: parameter > vlm_options.prompt > error - if prompt is not None: - user_prompt = prompt - elif not callable(self.vlm_options.prompt): - user_prompt = self.vlm_options.prompt + # Handle prompt parameter + if isinstance(prompt, str): + # Single prompt for all images + user_prompts = [prompt] * len(pil_images) + elif isinstance(prompt, list): + # List of prompts (one per image) + if len(prompt) != len(pil_images): + raise ValueError( + f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})" + ) + user_prompts = prompt else: - raise ValueError( - "vlm_options.prompt is callable but no prompt parameter provided to process_images. " - "Please provide a prompt parameter when calling process_images directly." 
- ) + raise ValueError(f"prompt must be str or list[str], got {type(prompt)}") - formatted_prompt = self.formulate_prompt(user_prompt) - prompts: list[str] = [formatted_prompt] * len(pil_images) + # Format prompts individually + prompts: list[str] = [ + self.formulate_prompt(user_prompt) for user_prompt in user_prompts + ] inputs = self.processor( text=prompts, images=pil_images, return_tensors="pt", padding=True diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py index 294fe9cc..60894869 100644 --- a/docling/models/vlm_models_inline/mlx_model.py +++ b/docling/models/vlm_models_inline/mlx_model.py @@ -71,110 +71,103 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: - for page in page_batch: + page_list = list(page_batch) + if not page_list: + return + + valid_pages = [] + invalid_pages = [] + + for page in page_list: assert page._backend is not None if not page._backend.is_valid(): - yield page + invalid_pages.append(page) else: - with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"): - assert page.size is not None + valid_pages.append(page) + # Process valid pages in batch + if valid_pages: + with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"): + # Prepare images and prompts for batch processing + images = [] + user_prompts = [] + pages_with_images = [] + + for page in valid_pages: + assert page.size is not None hi_res_image = page.get_image( scale=self.vlm_options.scale, max_size=self.vlm_options.max_size ) + + # Only process pages with valid images if hi_res_image is not None: - im_width, im_height = hi_res_image.size + images.append(hi_res_image) - # populate page_tags with predicted doc tags - page_tags = "" + # Define prompt structure + if callable(self.vlm_options.prompt): + user_prompt = self.vlm_options.prompt(page.parsed_page) + else: + user_prompt = self.vlm_options.prompt - if hi_res_image: - if hi_res_image.mode != "RGB": - hi_res_image = hi_res_image.convert("RGB") + user_prompts.append(user_prompt) + pages_with_images.append(page) - if callable(self.vlm_options.prompt): - user_prompt = self.vlm_options.prompt(page.parsed_page) - else: - user_prompt = self.vlm_options.prompt - prompt = self.apply_chat_template( - self.processor, self.config, user_prompt, num_images=1 - ) + # Use process_images for the actual inference + if images: # Only if we have valid images + predictions = list(self.process_images(images, user_prompts)) - # MLX models are not thread-safe - use global lock to serialize access - with _MLX_GLOBAL_LOCK: - _log.debug( - "MLX model: Acquired global lock for __call__ method" - ) - start_time = time.time() - _log.debug("start generating ...") + # Attach results to pages + for page, prediction in zip(pages_with_images, predictions): + page.predictions.vlm_response = prediction - # Call model to generate: - tokens: list[VlmPredictionToken] = [] - - output = "" - for token in self.stream_generate( - self.vlm_model, - self.processor, - prompt, - [hi_res_image], - max_tokens=self.max_tokens, - verbose=False, - temp=self.temperature, - ): - if len(token.logprobs.shape) == 1: - tokens.append( - VlmPredictionToken( - text=token.text, - token=token.token, - logprob=token.logprobs[token.token], - ) - ) - elif ( - len(token.logprobs.shape) == 2 - and token.logprobs.shape[0] == 1 - ): - tokens.append( - VlmPredictionToken( - text=token.text, - token=token.token, - 
logprob=token.logprobs[0, token.token], - ) - ) - else: - _log.warning( - f"incompatible shape for logprobs: {token.logprobs.shape}" - ) - - output += token.text - if "" in token.text: - break - - generation_time = time.time() - start_time - _log.debug("MLX model: Released global lock") - page_tags = output - - _log.debug( - f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)." - ) - page.predictions.vlm_response = VlmPrediction( - text=page_tags, - generation_time=generation_time, - generated_tokens=tokens, - ) - - yield page + # Yield all pages (valid and invalid) + for page in invalid_pages: + yield page + for page in valid_pages: + yield page def process_images( self, image_batch: Iterable[Union[Image, np.ndarray]], - prompt: Optional[str] = None, + prompt: Union[str, list[str]], ) -> Iterable[VlmPrediction]: + """Process raw images without page metadata. + + Args: + image_batch: Iterable of PIL Images or numpy arrays + prompt: Either: + - str: Single prompt used for all images + - list[str]: List of prompts (one per image, must match image count) + + Raises: + ValueError: If prompt list length doesn't match image count. + """ from mlx_vlm import generate + # Convert image batch to list for length validation + image_list = list(image_batch) + + if len(image_list) == 0: + return + + # Handle prompt parameter + if isinstance(prompt, str): + # Single prompt for all images + user_prompts = [prompt] * len(image_list) + elif isinstance(prompt, list): + # List of prompts (one per image) + if len(prompt) != len(image_list): + raise ValueError( + f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})" + ) + user_prompts = prompt + else: + raise ValueError(f"prompt must be str or list[str], got {type(prompt)}") + # MLX models are not thread-safe - use global lock to serialize access with _MLX_GLOBAL_LOCK: _log.debug("MLX model: Acquired global lock for thread safety") - for image in image_batch: + for image, user_prompt in zip(image_list, user_prompts): # Convert numpy array to PIL Image if needed if isinstance(image, np.ndarray): if image.ndim == 3 and image.shape[2] in [3, 4]: @@ -196,17 +189,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): if image.mode != "RGB": image = image.convert("RGB") - # Handle prompt with priority: parameter > vlm_options.prompt > error - if prompt is not None: - user_prompt = prompt - elif not callable(self.vlm_options.prompt): - user_prompt = self.vlm_options.prompt - else: - raise ValueError( - "vlm_options.prompt is callable but no prompt parameter provided to process_images. " - "Please provide a prompt parameter when calling process_images directly." 
- ) - # Use the MLX chat template approach like in the __call__ method formatted_prompt = self.apply_chat_template( self.processor, self.config, user_prompt, num_images=1 diff --git a/docling/models/vlm_models_inline/vllm_model.py b/docling/models/vlm_models_inline/vllm_model.py new file mode 100644 index 00000000..61c84cde --- /dev/null +++ b/docling/models/vlm_models_inline/vllm_model.py @@ -0,0 +1,277 @@ +import logging +import time +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np +from PIL.Image import Image + +from docling.datamodel.accelerator_options import ( + AcceleratorOptions, +) +from docling.datamodel.base_models import Page, VlmPrediction +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options_vlm_model import ( + InlineVlmOptions, + TransformersPromptStyle, +) +from docling.models.base_model import BaseVlmPageModel +from docling.models.utils.hf_model_download import ( + HuggingFaceModelDownloadMixin, +) +from docling.utils.accelerator_utils import decide_device +from docling.utils.profiling import TimeRecorder + +_log = logging.getLogger(__name__) + + +class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin): + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + accelerator_options: AcceleratorOptions, + vlm_options: InlineVlmOptions, + ): + self.enabled = enabled + + self.vlm_options = vlm_options + + if self.enabled: + from transformers import AutoProcessor + from vllm import LLM, SamplingParams + + self.device = decide_device( + accelerator_options.device, + supported_devices=vlm_options.supported_devices, + ) + _log.debug(f"Available device for VLM: {self.device}") + + self.max_new_tokens = vlm_options.max_new_tokens + self.temperature = vlm_options.temperature + + repo_cache_folder = vlm_options.repo_id.replace("/", "--") + + if artifacts_path is None: + artifacts_path = self.download_models(self.vlm_options.repo_id) + elif (artifacts_path / repo_cache_folder).exists(): + artifacts_path = artifacts_path / repo_cache_folder + + # Initialize VLLM LLM + llm_kwargs = { + "model": str(artifacts_path), + "model_impl": "transformers", + "limit_mm_per_prompt": {"image": 1}, + "trust_remote_code": vlm_options.trust_remote_code, + } + + # Add device-specific configurations + if self.device.startswith("cuda"): + # VLLM automatically detects GPU + pass + elif self.device == "cpu": + llm_kwargs["device"] = "cpu" + + # Add quantization if specified + if vlm_options.quantized: + if vlm_options.load_in_8bit: + llm_kwargs["quantization"] = "bitsandbytes" + + self.llm = LLM(**llm_kwargs) + + # Initialize processor for prompt formatting + self.processor = AutoProcessor.from_pretrained( + artifacts_path, + trust_remote_code=vlm_options.trust_remote_code, + ) + + # Set up sampling parameters + self.sampling_params = SamplingParams( + temperature=self.temperature, + max_tokens=self.max_new_tokens, + stop=vlm_options.stop_strings if vlm_options.stop_strings else None, + **vlm_options.extra_generation_config, + ) + + def __call__( + self, conv_res: ConversionResult, page_batch: Iterable[Page] + ) -> Iterable[Page]: + page_list = list(page_batch) + if not page_list: + return + + valid_pages = [] + invalid_pages = [] + + for page in page_list: + assert page._backend is not None + if not page._backend.is_valid(): + invalid_pages.append(page) + else: + valid_pages.append(page) + + # Process valid pages in batch + if valid_pages: + with 
TimeRecorder(conv_res, "vlm"): + # Prepare images and prompts for batch processing + images = [] + user_prompts = [] + pages_with_images = [] + + for page in valid_pages: + assert page.size is not None + hi_res_image = page.get_image( + scale=self.vlm_options.scale, max_size=self.vlm_options.max_size + ) + + # Only process pages with valid images + if hi_res_image is not None: + images.append(hi_res_image) + + # Define prompt structure + if callable(self.vlm_options.prompt): + user_prompt = self.vlm_options.prompt(page.parsed_page) + else: + user_prompt = self.vlm_options.prompt + + user_prompts.append(user_prompt) + pages_with_images.append(page) + + # Use process_images for the actual inference + if images: # Only if we have valid images + predictions = list(self.process_images(images, user_prompts)) + + # Attach results to pages + for page, prediction in zip(pages_with_images, predictions): + page.predictions.vlm_response = prediction + + # Yield all pages (valid and invalid) + for page in invalid_pages: + yield page + for page in valid_pages: + yield page + + def formulate_prompt(self, user_prompt: str) -> str: + """Formulate a prompt for the VLM.""" + + if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW: + return user_prompt + + elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct": + _log.debug("Using specialized prompt for Phi-4") + # Note: This might need adjustment for VLLM vs transformers + user_prompt_prefix = "<|user|>" + assistant_prompt = "<|assistant|>" + prompt_suffix = "<|end|>" + + prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}" + _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}") + + return prompt + + elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT: + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "This is a page from a document.", + }, + {"type": "image"}, + {"type": "text", "text": user_prompt}, + ], + } + ] + prompt = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return prompt + + raise RuntimeError( + f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}." + ) + + def process_images( + self, + image_batch: Iterable[Union[Image, np.ndarray]], + prompt: Union[str, list[str]], + ) -> Iterable[VlmPrediction]: + """Process raw images without page metadata in a single batched inference call. + + Args: + image_batch: Iterable of PIL Images or numpy arrays + prompt: Either: + - str: Single prompt used for all images + - list[str]: List of prompts (one per image, must match image count) + + Raises: + ValueError: If prompt list length doesn't match image count. + """ + pil_images: list[Image] = [] + + for img in image_batch: + # Convert numpy array to PIL Image if needed + if isinstance(img, np.ndarray): + if img.ndim == 3 and img.shape[2] in [3, 4]: + from PIL import Image as PILImage + + pil_img = PILImage.fromarray(img.astype(np.uint8)) + elif img.ndim == 2: + from PIL import Image as PILImage + + pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L") + else: + raise ValueError(f"Unsupported numpy array shape: {img.shape}") + else: + pil_img = img + + # Ensure image is in RGB mode (handles RGBA, L, etc.) 
+ if pil_img.mode != "RGB": + pil_img = pil_img.convert("RGB") + + pil_images.append(pil_img) + + if len(pil_images) == 0: + return + + # Handle prompt parameter + if isinstance(prompt, str): + # Single prompt for all images + user_prompts = [prompt] * len(pil_images) + elif isinstance(prompt, list): + # List of prompts (one per image) + if len(prompt) != len(pil_images): + raise ValueError( + f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})" + ) + user_prompts = prompt + else: + raise ValueError(f"prompt must be str or list[str], got {type(prompt)}") + + # Format prompts individually + prompts: list[str] = [ + self.formulate_prompt(user_prompt) for user_prompt in user_prompts + ] + + # Prepare VLLM inputs + llm_inputs = [] + for prompt, image in zip(prompts, pil_images): + llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) + + start_time = time.time() + outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params) + generation_time = time.time() - start_time + + # Logging tokens count for the first sample as a representative metric + if len(outputs) > 0: + num_tokens = len(outputs[0].outputs[0].token_ids) + _log.debug( + f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds." + ) + + for output in outputs: + yield VlmPrediction( + text=output.outputs[0].text, generation_time=generation_time + ) diff --git a/docling/pipeline/threaded_multistage_vlm_pipeline.py b/docling/pipeline/threaded_multistage_vlm_pipeline.py index 73c663c8..3ea2483a 100644 --- a/docling/pipeline/threaded_multistage_vlm_pipeline.py +++ b/docling/pipeline/threaded_multistage_vlm_pipeline.py @@ -693,6 +693,17 @@ class ThreadedMultiStageVlmPipeline(BasePipeline): accelerator_options=self.pipeline_options.accelerator_options, vlm_options=vlm_options, ) + elif vlm_options.inference_framework == InferenceFramework.VLLM: + from docling.models.vlm_models_inline.vllm_model import ( + VllmVlmModel, + ) + + model = VllmVlmModel( + enabled=True, + artifacts_path=art_path, + accelerator_options=self.pipeline_options.accelerator_options, + vlm_options=vlm_options, + ) else: raise ValueError( f"Unsupported inference framework: {vlm_options.inference_framework}" diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py index ab474fab..5a05b9a0 100644 --- a/docling/pipeline/vlm_pipeline.py +++ b/docling/pipeline/vlm_pipeline.py @@ -103,6 +103,17 @@ class VlmPipeline(PaginatedPipeline): vlm_options=vlm_options, ), ] + elif vlm_options.inference_framework == InferenceFramework.VLLM: + from docling.models.vlm_models_inline.vllm_model import VllmVlmModel + + self.build_pipe = [ + VllmVlmModel( + enabled=True, # must be always enabled for this pipeline to make sense. + artifacts_path=artifacts_path, + accelerator_options=pipeline_options.accelerator_options, + vlm_options=vlm_options, + ), + ] else: raise ValueError( f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}" diff --git a/pyproject.toml b/pyproject.toml index 678011f2..43ad8eee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,6 +255,7 @@ module = [ "huggingface_hub.*", "transformers.*", "pylatexenc.*", + "vllm.*", ] ignore_missing_imports = true
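Usage sketch for the new vLLM-backed specs (not part of this diff). It assumes the existing public entry points `DocumentConverter`, `PdfFormatOption`, `VlmPipelineOptions`, and `VlmPipeline` are unchanged by this PR; only `SMOLDOCLING_VLLM` is new here, and the input path is a placeholder.

```python
from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Select the vLLM-backed SmolDocling spec added in this diff.
pipeline_options = VlmPipelineOptions(
    vlm_options=vlm_model_specs.SMOLDOCLING_VLLM,
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("example.pdf")  # placeholder input path
print(result.document.export_to_markdown())
```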
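The revised `process_images` contract is shared by the transformers, MLX, and vLLM backends: a single `str` prompt is broadcast to every image, while a `list[str]` must match the image count or a `ValueError` is raised. A minimal sketch against `VllmVlmModel`, assuming model weights can be downloaded (via `artifacts_path=None`), a supported device is available, and using placeholder image paths:

```python
from PIL import Image

from docling.datamodel import vlm_model_specs
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.models.vlm_models_inline.vllm_model import VllmVlmModel

model = VllmVlmModel(
    enabled=True,
    artifacts_path=None,  # None -> weights fetched through download_models()
    accelerator_options=AcceleratorOptions(),
    vlm_options=vlm_model_specs.SMOLDOCLING_VLLM,
)

# Placeholder page images; process_images also accepts numpy arrays.
images = [Image.open(p).convert("RGB") for p in ("page1.png", "page2.png")]

# One prompt broadcast to all images:
predictions = list(model.process_images(images, "Convert this page to docling."))

# One prompt per image; a length mismatch raises ValueError:
predictions = list(
    model.process_images(
        images,
        ["Convert this page to docling.", "Convert this page to docling."],
    )
)

for pred in predictions:
    print(pred.text)
```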