refactoring the VLM part

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-05-13 10:01:37 +02:00
parent ee01e3cff0
commit 96862bd326
4 changed files with 19 additions and 8 deletions

View File

@@ -317,8 +317,7 @@ smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
# prompt="OCR the full page to markdown.",
prompt="OCR this image.",
prompt="OCR the full page to markdown.",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
)
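
For context, options objects like the granite-vision one above are consumed through docling's VLM pipeline. A minimal sketch of that wiring, assuming the usual VlmPipelineOptions / PdfFormatOption entry points and standard import paths (they may differ on this branch); the input path is a placeholder:

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Route PDF conversion through the VLM pipeline and select the granite-vision
# options edited in the hunk above (prompt: "OCR the full page to markdown.").
pipeline_options = VlmPipelineOptions(vlm_options=granite_vision_vlm_conversion_options)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("sample.pdf")  # placeholder input document
print(result.document.export_to_markdown())
```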

View File

@@ -18,6 +18,8 @@ _log = logging.getLogger(__name__)
class HuggingFaceVlmModel(BasePageModel):
"""
def __init__(
self,
enabled: bool,
@@ -89,6 +91,7 @@ class HuggingFaceVlmModel(BasePageModel):
),
# trust_remote_code=True,
) # .to(self.device)
"""
@staticmethod
def download_models(
@@ -111,6 +114,7 @@ class HuggingFaceVlmModel(BasePageModel):
return Path(download_path)
"""
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@@ -185,3 +189,4 @@ class HuggingFaceVlmModel(BasePageModel):
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page
"""

View File

@@ -42,9 +42,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
)
self.device = decide_device(accelerator_options.device)
self.device = "cpu" # device
self.device = "cpu" # FIXME
_log.debug(f"Available device for HuggingFace VLM: {self.device}")
_log.debug(f"Available device for VLM: {self.device}")
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
# PARAMETERS:
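
Note that the wrapper still pins the device to "cpu" (the FIXME), overriding whatever decide_device() chose from the accelerator options. A sketch of the kind of selection the FIXME presumably points back to; decide_device here is a stand-in, not docling's helper:

```python
import torch


def decide_device(requested: str = "auto") -> str:
    """Pick the best available torch device unless one was requested explicitly."""
    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():  # Apple silicon
        return "mps"
    return "cpu"
```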
@@ -154,6 +154,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
).to(self.device)
# Generate response
start_time = time.time()
generate_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=128,
@@ -162,13 +163,19 @@
)
generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
# num_tokens = len(generate_ids[0])
num_tokens = len(generate_ids[0])
generation_time = time.time() - start_time
response = self.processor.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")

View File

@@ -109,10 +109,10 @@ class VlmPipeline(PaginatedPipeline):
]
else:
_log.warning(
"falling back to HuggingFaceVlmModel (AutoModelForVision2Seq) pipeline"
"falling back to HuggingFaceVlmModel_AutoModelForVision2Seq pipeline"
)
self.build_pipe = [
HuggingFaceVlmModel(
HuggingFaceVlmModel_AutoModelForVision2Seq(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
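
The fallback above means the pipeline only builds a dedicated wrapper when the requested inference framework has one; otherwise it warns and falls back to the AutoModelForVision2Seq wrapper. A sketch of that dispatch, with stub classes standing in for the ones named in the diff (the real selection logic may differ):

```python
import logging
from enum import Enum

_log = logging.getLogger(__name__)


class InferenceFramework(str, Enum):  # stand-in for the docling enum
    TRANSFORMERS_AutoModelForCausalLM = "transformers-causallm"
    TRANSFORMERS_AutoModelForVision2Seq = "transformers-vision2seq"


class HuggingFaceVlmModel_AutoModelForCausalLM: ...  # stubs for the diff's classes
class HuggingFaceVlmModel_AutoModelForVision2Seq: ...


def select_vlm_model_cls(framework: InferenceFramework) -> type:
    """Return the page-model class for the requested framework, falling back to
    the AutoModelForVision2Seq wrapper when no dedicated class exists."""
    if framework == InferenceFramework.TRANSFORMERS_AutoModelForCausalLM:
        return HuggingFaceVlmModel_AutoModelForCausalLM
    _log.warning(
        "falling back to HuggingFaceVlmModel_AutoModelForVision2Seq pipeline"
    )
    return HuggingFaceVlmModel_AutoModelForVision2Seq
```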