Mirror of https://github.com/DS4SD/docling.git, synced 2025-08-02 15:32:30 +00:00
add generation options
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent 11e27930c4
commit 06342a5a28
@ -211,7 +211,8 @@ class PicDescVlmOptions(PicDescBaseOptions):
|
|||||||
|
|
||||||
repo_id: str
|
repo_id: str
|
||||||
prompt: str = "Describe this image in a few sentences."
|
prompt: str = "Describe this image in a few sentences."
|
||||||
max_new_tokens: int = 200
|
# Generation options passed to transformers' GenerationConfig; for the available keys see https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
|
||||||
|
generation_config: Dict[str, Any] = dict(max_new_tokens=200, do_sample=False)
|
||||||
|
|
||||||
|
|
||||||
# class PicDescSmolVlmOptions(PicDescVlmOptions):
|
# class PicDescSmolVlmOptions(PicDescVlmOptions):
|
||||||
|
@ -69,6 +69,7 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
|||||||
return Path(download_path)
|
return Path(download_path)
|
||||||
|
|
||||||
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
|
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
|
||||||
|
from transformers import GenerationConfig
|
||||||
|
|
||||||
# Create input messages
|
# Create input messages
|
||||||
messages = [
|
messages = [
|
||||||
@ -81,7 +82,6 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: set seed for reproducibility
|
|
||||||
# TODO: do batch generation
|
# TODO: do batch generation
|
||||||
|
|
||||||
for image in images:
|
for image in images:
|
||||||
@ -94,7 +94,8 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
|||||||
|
|
||||||
# Generate outputs
|
# Generate outputs
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
**inputs, max_new_tokens=self.options.max_new_tokens
|
**inputs,
|
||||||
|
generation_config=GenerationConfig(**self.options.generation_config),
|
||||||
)
|
)
|
||||||
generated_texts = self.processor.batch_decode(
|
generated_texts = self.processor.batch_decode(
|
||||||
generated_ids[:, inputs["input_ids"].shape[1] :],
|
generated_ids[:, inputs["input_ids"].shape[1] :],
|
||||||
|
Loading…
Reference in New Issue
Block a user