add generation options

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2025-02-06 11:12:28 +01:00
parent 11e27930c4
commit 06342a5a28
2 changed files with 5 additions and 3 deletions

@@ -211,7 +211,8 @@ class PicDescVlmOptions(PicDescBaseOptions):
     repo_id: str
     prompt: str = "Describe this image in a few sentences."
-    max_new_tokens: int = 200
+    # Config from here https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
+    generation_config: Dict[str, Any] = dict(max_new_tokens=200, do_sample=False)
 
 # class PicDescSmolVlmOptions(PicDescVlmOptions):
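With this change, callers no longer set a bare max_new_tokens integer; any keyword understood by transformers.GenerationConfig can be passed as a plain dict. A minimal usage sketch, assuming the class name and module path from this commit (the repo id and values are illustrative, not defaults):

    from docling.datamodel.pipeline_options import PicDescVlmOptions  # module path assumed

    # Any keyword accepted by transformers.GenerationConfig can go in this dict,
    # e.g. enabling sampling instead of the greedy default defined above.
    options = PicDescVlmOptions(
        repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",  # illustrative model repo
        generation_config=dict(max_new_tokens=120, do_sample=True, temperature=0.7),
    )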

@@ -69,6 +69,7 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
         return Path(download_path)
 
     def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+        from transformers import GenerationConfig
 
         # Create input messages
         messages = [
@@ -81,7 +82,6 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             },
         ]
 
-        # TODO: set seed for reproducibility
         # TODO: do batch generation
         for image in images:
@@ -94,7 +94,8 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             # Generate outputs
             generated_ids = self.model.generate(
-                **inputs, max_new_tokens=self.options.max_new_tokens
+                **inputs,
+                generation_config=GenerationConfig(**self.options.generation_config),
             )
             generated_texts = self.processor.batch_decode(
                 generated_ids[:, inputs["input_ids"].shape[1] :],
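For reference, the new generate call follows the standard transformers pattern: the options dict is expanded into a GenerationConfig right before generation, and only the tokens produced after the prompt are decoded. A self-contained sketch of that pattern (the helper name and its arguments are hypothetical, not code from this repository):

    from transformers import GenerationConfig

    def run_generation(model, processor, inputs, generation_options: dict) -> str:
        # Build a GenerationConfig from a plain dict of options, mirroring the
        # change above instead of passing max_new_tokens directly to generate().
        generated_ids = model.generate(
            **inputs,
            generation_config=GenerationConfig(**generation_options),
        )
        # Decode only the newly generated tokens, skipping the prompt portion.
        texts = processor.batch_decode(
            generated_ids[:, inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
        )
        return texts[0].strip()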