Mirror of https://github.com/DS4SD/docling.git

Commit 96862bd326 (parent ee01e3cff0)

    refactoring the VLM part

    Signed-off-by: Peter Staar <taa@zurich.ibm.com>
@@ -317,8 +317,7 @@ smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
 
 granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     repo_id="ibm-granite/granite-vision-3.1-2b-preview",
-    # prompt="OCR the full page to markdown.",
-    prompt="OCR this image.",
+    prompt="OCR the full page to markdown.",
     response_format=ResponseFormat.MARKDOWN,
     inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
 )
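For context, presets like this are consumed by docling's VLM pipeline. A minimal usage sketch, assuming the docling API of roughly this era (`VlmPipelineOptions`, `PdfFormatOption`, and `VlmPipeline` as in the project's own examples; exact import paths may differ between versions):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Route PDF conversion through the VLM pipeline with the granite-vision preset.
pipeline_options = VlmPipelineOptions()
pipeline_options.vlm_options = granite_vision_vlm_conversion_options

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("report.pdf")  # any local PDF path
print(result.document.export_to_markdown())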
@@ -18,6 +18,8 @@ _log = logging.getLogger(__name__)
 
 
 class HuggingFaceVlmModel(BasePageModel):
+
+    """
     def __init__(
         self,
         enabled: bool,
@@ -89,7 +91,8 @@ class HuggingFaceVlmModel(BasePageModel):
             ),
             # trust_remote_code=True,
         )  # .to(self.device)
+    """
 
     @staticmethod
     def download_models(
         repo_id: str,
@@ -111,6 +114,7 @@ class HuggingFaceVlmModel(BasePageModel):
 
         return Path(download_path)
 
+    """
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
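The `download_models` helper kept above returns a local `Path` for a Hugging Face repo. A sketch of what such a helper typically does, built on the public `huggingface_hub.snapshot_download` API; this illustrates the pattern and is not docling's exact implementation:

from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download

def download_models(
    repo_id: str,
    local_dir: Optional[Path] = None,
    force: bool = False,
) -> Path:
    # Fetch (or reuse from the local cache) a full model snapshot
    # and return the directory it was materialized into.
    download_path = snapshot_download(
        repo_id=repo_id,
        local_dir=str(local_dir) if local_dir else None,
        force_download=force,
    )
    return Path(download_path)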
@@ -185,3 +189,4 @@ class HuggingFaceVlmModel(BasePageModel):
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)
 
                 yield page
+    """
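The hunks above deactivate the old monolithic `HuggingFaceVlmModel` by wrapping its body in triple-quoted strings, but they still document the page-model contract the framework-specific replacements implement: `__call__` consumes a batch of pages and yields them back with `page.predictions.vlm_response` filled in. A no-op sketch of that contract, using the type names visible in the diff; the docling import paths are assumptions based on the project layout at the time:

from collections.abc import Iterable

# Assumed import locations; the diff only shows the type names.
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.models.base_model import BasePageModel

class NoOpVlmModel(BasePageModel):  # hypothetical stand-in, not from this commit
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        for page in page_batch:
            # A real model would run VLM inference on the page image here.
            page.predictions.vlm_response = VlmPrediction(text="")
            yield page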
@@ -42,9 +42,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
             )
 
             self.device = decide_device(accelerator_options.device)
-            self.device = "cpu"  # device
+            self.device = "cpu"  # FIXME
 
-            _log.debug(f"Available device for HuggingFace VLM: {self.device}")
+            _log.debug(f"Available device for VLM: {self.device}")
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
             # PARAMETERS:
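Note that `decide_device(accelerator_options.device)` is computed and then immediately overridden with a hard-coded "cpu" (hence the `FIXME`). For reference, a `decide_device`-style helper usually resolves an "auto" request to the best available backend; the following is a hypothetical reimplementation of that idea, not docling's actual function:

import torch

def resolve_device(requested: str = "auto") -> str:
    # Honor an explicit request; otherwise pick the best available backend.
    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"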
@@ -154,6 +154,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         ).to(self.device)
 
         # Generate response
+        start_time = time.time()
         generate_ids = self.vlm_model.generate(
             **inputs,
             max_new_tokens=128,
@@ -162,13 +163,19 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         )
         generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
 
-        # num_tokens = len(generate_ids[0])
+        num_tokens = len(generate_ids[0])
+        generation_time = time.time() - start_time
+
         response = self.processor.batch_decode(
             generate_ids,
             skip_special_tokens=True,
             clean_up_tokenization_spaces=False,
         )[0]
 
+        _log.debug(
+            f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+        )
+
         # inference_time = time.time() - start_time
         # tokens_per_second = num_tokens / generation_time
         # print("")
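These two hunks turn the commented-out token count into live instrumentation: the prompt tokens are sliced off so only newly generated tokens are counted, and the elapsed time is logged. A self-contained sketch of the same pattern against the public transformers `generate`/`batch_decode` API; `model`, `processor`, and `inputs` are placeholders for whatever the caller has loaded:

import time

def timed_generate(model, processor, inputs, max_new_tokens: int = 128) -> str:
    start_time = time.time()
    generate_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Slice off the prompt so only newly generated tokens are counted.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
    generation_time = time.time() - start_time
    num_tokens = generate_ids.shape[1]
    print(
        f"Generated {num_tokens} tokens in {generation_time:.2f}s "
        f"({num_tokens / generation_time:.1f} tokens/s)"
    )
    return processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]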
@@ -109,10 +109,10 @@ class VlmPipeline(PaginatedPipeline):
                 ]
             else:
                 _log.warning(
-                    "falling back to HuggingFaceVlmModel (AutoModelForVision2Seq) pipeline"
+                    "falling back to HuggingFaceVlmModel_AutoModelForVision2Seq pipeline"
                 )
                 self.build_pipe = [
-                    HuggingFaceVlmModel(
+                    HuggingFaceVlmModel_AutoModelForVision2Seq(
                         enabled=True,  # must be always enabled for this pipeline to make sense.
                         artifacts_path=artifacts_path,
                         accelerator_options=pipeline_options.accelerator_options,
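The fallback branch now names the concrete Vision2Seq class instead of the old monolithic one. Paraphrased as a sketch, the dispatch the pipeline performs looks roughly like the following; only the class names and the `InferenceFramework` enum are taken from this diff, the surrounding condition and variable names are illustrative:

# Hypothetical dispatch on the configured inference framework; the non-matching
# configurations fall back to the AutoModelForVision2Seq wrapper with a warning.
if vlm_options.inference_framework == InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq:
    model_cls = HuggingFaceVlmModel_AutoModelForVision2Seq
else:
    _log.warning(
        "falling back to HuggingFaceVlmModel_AutoModelForVision2Seq pipeline"
    )
    model_cls = HuggingFaceVlmModel_AutoModelForVision2Seq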