mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
refactoring the VLM part
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
parent
ee01e3cff0
commit
96862bd326
@ -317,8 +317,7 @@ smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
|
|||||||
|
|
||||||
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
|
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
|
||||||
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
|
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
|
||||||
# prompt="OCR the full page to markdown.",
|
prompt="OCR the full page to markdown.",
|
||||||
prompt="OCR this image.",
|
|
||||||
response_format=ResponseFormat.MARKDOWN,
|
response_format=ResponseFormat.MARKDOWN,
|
||||||
inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
|
inference_framework=InferenceFramework.TRANSFORMERS_AutoModelForVision2Seq,
|
||||||
)
|
)
|
||||||
|
@ -18,6 +18,8 @@ _log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class HuggingFaceVlmModel(BasePageModel):
|
class HuggingFaceVlmModel(BasePageModel):
|
||||||
|
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
@ -89,6 +91,7 @@ class HuggingFaceVlmModel(BasePageModel):
|
|||||||
),
|
),
|
||||||
# trust_remote_code=True,
|
# trust_remote_code=True,
|
||||||
) # .to(self.device)
|
) # .to(self.device)
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def download_models(
|
def download_models(
|
||||||
@ -111,6 +114,7 @@ class HuggingFaceVlmModel(BasePageModel):
|
|||||||
|
|
||||||
return Path(download_path)
|
return Path(download_path)
|
||||||
|
|
||||||
|
"""
|
||||||
def __call__(
|
def __call__(
|
||||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||||
) -> Iterable[Page]:
|
) -> Iterable[Page]:
|
||||||
@ -185,3 +189,4 @@ class HuggingFaceVlmModel(BasePageModel):
|
|||||||
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
||||||
|
|
||||||
yield page
|
yield page
|
||||||
|
"""
|
||||||
|
@ -42,9 +42,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.device = decide_device(accelerator_options.device)
|
self.device = decide_device(accelerator_options.device)
|
||||||
self.device = "cpu" # device
|
self.device = "cpu" # FIXME
|
||||||
|
|
||||||
_log.debug(f"Available device for HuggingFace VLM: {self.device}")
|
_log.debug(f"Available device for VLM: {self.device}")
|
||||||
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
|
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
|
||||||
|
|
||||||
# PARAMETERS:
|
# PARAMETERS:
|
||||||
@ -154,6 +154,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
|
|||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
# Generate response
|
# Generate response
|
||||||
|
start_time = time.time()
|
||||||
generate_ids = self.vlm_model.generate(
|
generate_ids = self.vlm_model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
max_new_tokens=128,
|
max_new_tokens=128,
|
||||||
@ -162,13 +163,19 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
|
|||||||
)
|
)
|
||||||
generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
|
generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
|
||||||
|
|
||||||
# num_tokens = len(generate_ids[0])
|
num_tokens = len(generate_ids[0])
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
|
||||||
response = self.processor.batch_decode(
|
response = self.processor.batch_decode(
|
||||||
generate_ids,
|
generate_ids,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
_log.debug(
|
||||||
|
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
|
||||||
|
)
|
||||||
|
|
||||||
# inference_time = time.time() - start_time
|
# inference_time = time.time() - start_time
|
||||||
# tokens_per_second = num_tokens / generation_time
|
# tokens_per_second = num_tokens / generation_time
|
||||||
# print("")
|
# print("")
|
||||||
|
@ -109,10 +109,10 @@ class VlmPipeline(PaginatedPipeline):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
_log.warning(
|
_log.warning(
|
||||||
"falling back to HuggingFaceVlmModel (AutoModelForVision2Seq) pipeline"
|
"falling back to HuggingFaceVlmModel_AutoModelForVision2Seq pipeline"
|
||||||
)
|
)
|
||||||
self.build_pipe = [
|
self.build_pipe = [
|
||||||
HuggingFaceVlmModel(
|
HuggingFaceVlmModel_AutoModelForVision2Seq(
|
||||||
enabled=True, # must be always enabled for this pipeline to make sense.
|
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||||
artifacts_path=artifacts_path,
|
artifacts_path=artifacts_path,
|
||||||
accelerator_options=pipeline_options.accelerator_options,
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
|
Loading…
Reference in New Issue
Block a user