Add VLLM backend support, optimize process_images
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
@@ -27,6 +27,7 @@ class ResponseFormat(str, Enum):
 class InferenceFramework(str, Enum):
     MLX = "mlx"
     TRANSFORMERS = "transformers"
+    VLLM = "vllm"


 class TransformersModelType(str, Enum):
@@ -44,6 +44,20 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     temperature=0.0,
 )

+SMOLDOCLING_VLLM = InlineVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.VLLM,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 # GraniteVision
 GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
@@ -60,6 +74,20 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     temperature=0.0,
 )

+GRANITE_VISION_VLLM = InlineVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.2-2b",
+    prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.VLLM,
+    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 GRANITE_VISION_OLLAMA = ApiVlmOptions(
     url=AnyUrl("http://localhost:11434/v1/chat/completions"),
     params={"model": "granite3.2-vision:2b"},
@@ -158,5 +186,7 @@ DOLPHIN_TRANSFORMERS = InlineVlmOptions(

 class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
+    SMOLDOCLING_VLLM = "smoldocling_vllm"
     GRANITE_VISION = "granite_vision"
+    GRANITE_VISION_VLLM = "granite_vision_vllm"
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
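Usage note (not part of the diff): a minimal sketch of selecting one of the new vLLM presets through the standard converter setup. It assumes the presets are exposed via docling.datamodel.vlm_model_specs, as for the existing presets, and that the usual VlmPipeline wiring (VlmPipelineOptions, PdfFormatOption) applies; the input file name is a placeholder.

from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Pick the vLLM-backed SmolDocling preset added in this commit.
pipeline_options = VlmPipelineOptions(vlm_options=vlm_model_specs.SMOLDOCLING_VLLM)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("sample.pdf")  # placeholder input
print(result.document.export_to_markdown())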
@@ -37,9 +37,21 @@ class BaseVlmModel(ABC):

     @abstractmethod
     def process_images(
-        self, image_batch: Iterable[Union[Image, np.ndarray]]
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata."""
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """


 class BaseVlmPageModel(BasePageModel, BaseVlmModel):
@@ -55,23 +67,6 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel):
     ) -> Iterable[Page]:
         """Extract images from pages, process them, and attach results back."""

-    @abstractmethod
-    def process_images(
-        self,
-        image_batch: Iterable[Union[Image, np.ndarray]],
-        prompt: Optional[str] = None,
-    ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata.
-
-        Args:
-            image_batch: Iterable of PIL Images or numpy arrays
-            prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
-                If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
-
-        Raises:
-            ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
-        """
-

 EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)

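Usage note (not part of the diff): with the signature above, any BaseVlmModel implementation can be driven directly. In the sketch below, `model` stands for an already constructed backend (transformers, MLX, or vLLM) and the image path is a placeholder.

import numpy as np
from PIL import Image

images = [
    Image.open("page_0.png"),                 # PIL image (placeholder path)
    np.zeros((640, 480, 3), dtype=np.uint8),  # numpy RGB arrays are also accepted
]

# One prompt reused for every image in the batch
preds = list(model.process_images(images, "Convert this page to docling."))

# Or one prompt per image; lengths must match, otherwise ValueError is raised
preds = list(
    model.process_images(
        images,
        ["Convert this page to docling.", "Convert this page to markdown."],
    )
)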
@@ -0,0 +1 @@
+
@@ -125,55 +125,59 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        for page in page_batch:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
             if not page._backend.is_valid():
-                yield page
+                invalid_pages.append(page)
             else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-                    hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
-                    )
-                    # Define prompt structure
-                    if callable(self.vlm_options.prompt):
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
-                    prompt = self.formulate_prompt(user_prompt)
-
-                    inputs = self.processor(
-                        text=prompt, images=[hi_res_image], return_tensors="pt"
-                    ).to(self.device)
-
-                    start_time = time.time()
-                    # Call model to generate:
-                    generated_ids = self.vlm_model.generate(
-                        **inputs,
-                        max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,
-                        temperature=self.temperature,
-                        generation_config=self.generation_config,
-                        **self.vlm_options.extra_generation_config,
-                    )
-
-                    generation_time = time.time() - start_time
-                    generated_texts = self.processor.batch_decode(
-                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=True,
-                    )[0]
-
-                    num_tokens = len(generated_ids[0])
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=generated_texts,
-                        generation_time=generation_time,
-                    )
-
-                yield page
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+
+                        # Define prompt structure
+                        if callable(self.vlm_options.prompt):
+                            user_prompt = self.vlm_options.prompt(page.parsed_page)
+                        else:
+                            user_prompt = self.vlm_options.prompt
+
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page

     def formulate_prompt(self, user_prompt: str) -> str:
@@ -221,9 +225,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
     def process_images(
         self,
         image_batch: Iterable[Union[Image, np.ndarray]],
-        prompt: Optional[str] = None,
+        prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata in a single batched inference call."""
+        """Process raw images without page metadata in a single batched inference call.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
         pil_images: list[Image] = []

         for img in image_batch:
@@ -251,19 +265,24 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         if len(pil_images) == 0:
             return

-        # Handle prompt with priority: parameter > vlm_options.prompt > error
-        if prompt is not None:
-            user_prompt = prompt
-        elif not callable(self.vlm_options.prompt):
-            user_prompt = self.vlm_options.prompt
-        else:
-            raise ValueError(
-                "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
-                "Please provide a prompt parameter when calling process_images directly."
-            )
-
-        formatted_prompt = self.formulate_prompt(user_prompt)
-        prompts: list[str] = [formatted_prompt] * len(pil_images)
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(pil_images)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # Format prompts individually
+        prompts: list[str] = [
+            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
+        ]

         inputs = self.processor(
             text=prompts, images=pil_images, return_tensors="pt", padding=True
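Usage note (not part of the diff): the `__call__` refactor above resolves a callable `vlm_options.prompt` once per page and forwards the resulting list to `process_images`, which is why the prompt parameter now accepts `list[str]`. A sketch of a per-page prompt callable, assuming the options model tolerates a callable prompt (as the `callable(...)` check in the hunk implies) and that the presets live in docling.datamodel.vlm_model_specs:

from docling.datamodel import vlm_model_specs


def page_prompt(parsed_page) -> str:
    # Hypothetical policy: always ask for bare markdown output.
    return "Convert this page to markdown. Do not miss any text and only output the bare markdown!"


options = vlm_model_specs.GRANITE_VISION_TRANSFORMERS.model_copy(
    update={"prompt": page_prompt}
)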
@@ -71,110 +71,103 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
-                    assert page.size is not None
-                    hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
-                    )
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-
-                    if callable(self.vlm_options.prompt):
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
-                    prompt = self.apply_chat_template(
-                        self.processor, self.config, user_prompt, num_images=1
-                    )
-
-                    # MLX models are not thread-safe - use global lock to serialize access
-                    with _MLX_GLOBAL_LOCK:
-                        _log.debug(
-                            "MLX model: Acquired global lock for __call__ method"
-                        )
-                        start_time = time.time()
-                        _log.debug("start generating ...")
-
-                        # Call model to generate:
-                        tokens: list[VlmPredictionToken] = []
-
-                        output = ""
-                        for token in self.stream_generate(
-                            self.vlm_model,
-                            self.processor,
-                            prompt,
-                            [hi_res_image],
-                            max_tokens=self.max_tokens,
-                            verbose=False,
-                            temp=self.temperature,
-                        ):
-                            if len(token.logprobs.shape) == 1:
-                                tokens.append(
-                                    VlmPredictionToken(
-                                        text=token.text,
-                                        token=token.token,
-                                        logprob=token.logprobs[token.token],
-                                    )
-                                )
-                            elif (
-                                len(token.logprobs.shape) == 2
-                                and token.logprobs.shape[0] == 1
-                            ):
-                                tokens.append(
-                                    VlmPredictionToken(
-                                        text=token.text,
-                                        token=token.token,
-                                        logprob=token.logprobs[0, token.token],
-                                    )
-                                )
-                            else:
-                                _log.warning(
-                                    f"incompatible shape for logprobs: {token.logprobs.shape}"
-                                )
-
-                            output += token.text
-                            if "</doctag>" in token.text:
-                                break
-
-                        generation_time = time.time() - start_time
-                        _log.debug("MLX model: Released global lock")
-                        page_tags = output
-
-                        _log.debug(
-                            f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
-                        )
-                        page.predictions.vlm_response = VlmPrediction(
-                            text=page_tags,
-                            generation_time=generation_time,
-                            generated_tokens=tokens,
-                        )
-
-                yield page
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                invalid_pages.append(page)
+            else:
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+
+                        # Define prompt structure
+                        if callable(self.vlm_options.prompt):
+                            user_prompt = self.vlm_options.prompt(page.parsed_page)
+                        else:
+                            user_prompt = self.vlm_options.prompt
+
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page

     def process_images(
         self,
         image_batch: Iterable[Union[Image, np.ndarray]],
-        prompt: Optional[str] = None,
+        prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
         from mlx_vlm import generate

+        # Convert image batch to list for length validation
+        image_list = list(image_batch)
+
+        if len(image_list) == 0:
+            return
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(image_list)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(image_list):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
         # MLX models are not thread-safe - use global lock to serialize access
         with _MLX_GLOBAL_LOCK:
             _log.debug("MLX model: Acquired global lock for thread safety")
-            for image in image_batch:
+            for image, user_prompt in zip(image_list, user_prompts):
                 # Convert numpy array to PIL Image if needed
                 if isinstance(image, np.ndarray):
                     if image.ndim == 3 and image.shape[2] in [3, 4]:
@@ -196,17 +189,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 if image.mode != "RGB":
                     image = image.convert("RGB")

-                # Handle prompt with priority: parameter > vlm_options.prompt > error
-                if prompt is not None:
-                    user_prompt = prompt
-                elif not callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt
-                else:
-                    raise ValueError(
-                        "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
-                        "Please provide a prompt parameter when calling process_images directly."
-                    )
-
                 # Use the MLX chat template approach like in the __call__ method
                 formatted_prompt = self.apply_chat_template(
                     self.processor, self.config, user_prompt, num_images=1
docling/models/vlm_models_inline/vllm_model.py (new file, 277 lines)
@@ -0,0 +1,277 @@
+import logging
+import time
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+
+from docling.datamodel.accelerator_options import (
+    AcceleratorOptions,
+)
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersPromptStyle,
+)
+from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.hf_model_download import (
+    HuggingFaceModelDownloadMixin,
+)
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: InlineVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+            from transformers import AutoProcessor
+            from vllm import LLM, SamplingParams
+
+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=vlm_options.supported_devices,
+            )
+            _log.debug(f"Available device for VLM: {self.device}")
+
+            self.max_new_tokens = vlm_options.max_new_tokens
+            self.temperature = vlm_options.temperature
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            # Initialize VLLM LLM
+            llm_kwargs = {
+                "model": str(artifacts_path),
+                "model_impl": "transformers",
+                "limit_mm_per_prompt": {"image": 1},
+                "trust_remote_code": vlm_options.trust_remote_code,
+            }
+
+            # Add device-specific configurations
+            if self.device.startswith("cuda"):
+                # VLLM automatically detects GPU
+                pass
+            elif self.device == "cpu":
+                llm_kwargs["device"] = "cpu"
+
+            # Add quantization if specified
+            if vlm_options.quantized:
+                if vlm_options.load_in_8bit:
+                    llm_kwargs["quantization"] = "bitsandbytes"
+
+            self.llm = LLM(**llm_kwargs)
+
+            # Initialize processor for prompt formatting
+            self.processor = AutoProcessor.from_pretrained(
+                artifacts_path,
+                trust_remote_code=vlm_options.trust_remote_code,
+            )
+
+            # Set up sampling parameters
+            self.sampling_params = SamplingParams(
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
+                **vlm_options.extra_generation_config,
+            )
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                invalid_pages.append(page)
+            else:
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+
+                        # Define prompt structure
+                        if callable(self.vlm_options.prompt):
+                            user_prompt = self.vlm_options.prompt(page.parsed_page)
+                        else:
+                            user_prompt = self.vlm_options.prompt
+
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def formulate_prompt(self, user_prompt: str) -> str:
+        """Formulate a prompt for the VLM."""
+
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
+            return user_prompt
+
+        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            _log.debug("Using specialized prompt for Phi-4")
+            # Note: This might need adjustment for VLLM vs transformers
+            user_prompt_prefix = "<|user|>"
+            assistant_prompt = "<|assistant|>"
+            prompt_suffix = "<|end|>"
+
+            prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+
+        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "This is a page from a document.",
+                        },
+                        {"type": "image"},
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+            ]
+            prompt = self.processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+            return prompt
+
+        raise RuntimeError(
+            f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+        )
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata in a single batched inference call.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+        pil_images: list[Image] = []
+
+        for img in image_batch:
+            # Convert numpy array to PIL Image if needed
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+
+            # Ensure image is in RGB mode (handles RGBA, L, etc.)
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+
+            pil_images.append(pil_img)
+
+        if len(pil_images) == 0:
+            return
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(pil_images)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # Format prompts individually
+        prompts: list[str] = [
+            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
+        ]
+
+        # Prepare VLLM inputs
+        llm_inputs = []
+        for prompt, image in zip(prompts, pil_images):
+            llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
+
+        start_time = time.time()
+        outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)
+        generation_time = time.time() - start_time
+
+        # Logging tokens count for the first sample as a representative metric
+        if len(outputs) > 0:
+            num_tokens = len(outputs[0].outputs[0].token_ids)
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+
+        for output in outputs:
+            yield VlmPrediction(
+                text=output.outputs[0].text, generation_time=generation_time
+            )
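Usage note (not part of the diff): a minimal sketch of driving the new backend directly, outside a pipeline. It assumes default accelerator settings, that the presets above live in docling.datamodel.vlm_model_specs, and that the SmolDocling weights can be fetched via the usual download path; the image path is a placeholder.

from PIL import Image

from docling.datamodel import vlm_model_specs
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.models.vlm_models_inline.vllm_model import VllmVlmModel

model = VllmVlmModel(
    enabled=True,
    artifacts_path=None,  # None triggers download_models() for the configured repo_id
    accelerator_options=AcceleratorOptions(),
    vlm_options=vlm_model_specs.SMOLDOCLING_VLLM,
)

page_image = Image.open("page_0.png")  # placeholder path
predictions = list(model.process_images([page_image], "Convert this page to docling."))
print(predictions[0].text)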
@@ -693,6 +693,17 @@ class ThreadedMultiStageVlmPipeline(BasePipeline):
                 accelerator_options=self.pipeline_options.accelerator_options,
                 vlm_options=vlm_options,
             )
+        elif vlm_options.inference_framework == InferenceFramework.VLLM:
+            from docling.models.vlm_models_inline.vllm_model import (
+                VllmVlmModel,
+            )
+
+            model = VllmVlmModel(
+                enabled=True,
+                artifacts_path=art_path,
+                accelerator_options=self.pipeline_options.accelerator_options,
+                vlm_options=vlm_options,
+            )
         else:
             raise ValueError(
                 f"Unsupported inference framework: {vlm_options.inference_framework}"
@@ -103,6 +103,17 @@ class VlmPipeline(PaginatedPipeline):
                     vlm_options=vlm_options,
                 ),
             ]
+        elif vlm_options.inference_framework == InferenceFramework.VLLM:
+            from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
+
+            self.build_pipe = [
+                VllmVlmModel(
+                    enabled=True,  # must be always enabled for this pipeline to make sense.
+                    artifacts_path=artifacts_path,
+                    accelerator_options=pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                ),
+            ]
         else:
             raise ValueError(
                 f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
@@ -255,6 +255,7 @@ module = [
     "huggingface_hub.*",
     "transformers.*",
     "pylatexenc.*",
+    "vllm.*",
 ]
 ignore_missing_imports = true
