reformatted the code

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-05-16 16:31:11 +02:00
parent d5b6c871cf
commit 0c7c7c11c2
9 changed files with 96 additions and 85 deletions

View File

@@ -186,9 +186,9 @@ class DocumentConverter:
         Tuple[Type[BasePipeline], str], BasePipeline
     ] = {}

-    def _get_initialized_pipelines(self) -> dict[
-        tuple[Type[BasePipeline], str], BasePipeline
-    ]:
+    def _get_initialized_pipelines(
+        self,
+    ) -> dict[tuple[Type[BasePipeline], str], BasePipeline]:
         return self.initialized_pipelines

     def _get_pipeline_options_hash(self, pipeline_options: PipelineOptions) -> str:

View File

@@ -6,7 +6,6 @@ _log = logging.getLogger(__name__)

 class HuggingFaceVlmModel:
     @staticmethod
     def map_device_to_cpu_if_mlx(device: str) -> str:
-
         if device == "mps":

View File

@@ -76,8 +76,6 @@ class HuggingFaceMlxModel(BasePageModel):
                 assert page.size is not None

                 hi_res_image = page.get_image(scale=self.vlm_options.scale)
-                hi_res_image.save("./scratch/page.png")
-
                 if hi_res_image is not None:
                     im_width, im_height = hi_res_image.size
@@ -128,7 +126,9 @@ class HuggingFaceMlxModel(BasePageModel):
                             )
                         )
                     else:
-                        _log.warning(f"incompatible shape for logprobs: {token.logprobs.shape}")
+                        _log.warning(
+                            f"incompatible shape for logprobs: {token.logprobs.shape}"
+                        )

                 output += token.text
                 if "</doctag>" in token.text:

View File

@@ -42,7 +42,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
             )

             self.device = decide_device(accelerator_options.device)
-            self.device = HuggingFaceVlmMode.map_device_to_cpu_if_mlx(self.device)
+            self.device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx(self.device)

             _log.debug(f"Available device for VLM: {self.device}")
             self.use_cache = vlm_options.use_kv_cache
@@ -153,7 +153,9 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
-                    page.predictions.vlm_response = VlmPrediction(text=response, generation_time=generation_time)
+                    page.predictions.vlm_response = VlmPrediction(
+                        text=response, generation_time=generation_time
+                    )

                 yield page

View File

@@ -39,7 +39,7 @@ class HuggingFaceVlmModel_AutoModelForVision2Seq(BasePageModel):
             )

             self.device = decide_device(accelerator_options.device)
-            self.device = HuggingFaceVlmMode.map_device_to_cpu_if_mlx(self.device)
+            self.device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx(self.device)

             _log.debug(f"Available device for HuggingFace VLM: {self.device}")

View File

@@ -39,7 +39,7 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
             )

             self.device = decide_device(accelerator_options.device)
-            self.device = HuggingFaceVlmMode.map_device_to_cpu_if_mlx(self.device)
+            self.device = HuggingFaceVlmModel.map_device_to_cpu_if_mlx(self.device)

             self.use_cache = vlm_options.use_kv_cache
             self.max_new_tokens = vlm_options.max_new_tokens
self.max_new_tokens = vlm_options.max_new_tokens self.max_new_tokens = vlm_options.max_new_tokens

View File

@@ -1,11 +1,23 @@
-import re
 import logging
+import re
 from io import BytesIO
 from pathlib import Path
 from typing import List, Optional, Union, cast

 # from docling_core.types import DoclingDocument
-from docling_core.types.doc import BoundingBox, DocItem, ImageRef, PictureItem, TextItem
+from docling_core.types.doc import (
+    BoundingBox,
+    DocItem,
+    DoclingDocument,
+    ImageRef,
+    PictureItem,
+    ProvenanceItem,
+    TextItem,
+)
+from docling_core.types.doc.base import (
+    BoundingBox,
+    Size,
+)
 from docling_core.types.doc.document import DocTagsDocument
 from PIL import Image as PILImage
@@ -20,14 +32,6 @@ from docling.datamodel.pipeline_model_specializations import (
     InferenceFramework,
     ResponseFormat,
 )
-from docling_core.types.doc.base import (
-    Size,
-    BoundingBox,
-)
-from docling_core.types.doc import (
-    ProvenanceItem,
-    DoclingDocument
-)
 from docling.datamodel.pipeline_options import (
     VlmPipelineOptions,
 )
@@ -168,6 +172,7 @@ class VlmPipeline(PaginatedPipeline):
                 self.pipeline_options.vlm_options.response_format
                 == ResponseFormat.DOCTAGS
             ):
+                """
                 doctags_list = []
                 image_list = []
                 for page in conv_res.pages:
@@ -207,6 +212,9 @@ class VlmPipeline(PaginatedPipeline):
                         txt = self.extract_text_from_backend(page, crop_bbox)
                         element.text = txt
                         element.orig = txt
+                """
+                conv_res.document = self._turn_dt_into_doc(conv_res)
+
             elif (
                 self.pipeline_options.vlm_options.response_format
                 == ResponseFormat.MARKDOWN
@@ -271,21 +279,18 @@ class VlmPipeline(PaginatedPipeline):
         if self.force_backend_text:
             scale = self.pipeline_options.images_scale
             for element, _level in conv_res.document.iterate_items():
-                if (not isinstance(element, TextItem)
-                    or len(element.prov) == 0
-                ):
+                if not isinstance(element, TextItem) or len(element.prov) == 0:
                     continue
                 crop_bbox = (
                     element.prov[0]
                     .bbox.scaled(scale=scale)
-                    .to_top_left_origin(
-                        page_height=page.size.height * scale
-                    )
+                    .to_top_left_origin(page_height=page.size.height * scale)
                 )

                 txt = self.extract_text_from_backend(page, crop_bbox)
                 element.text = txt
                 element.orig = txt

+        return conv_res.document
+
     """
     def _turn_md_into_doc(self, conv_res):
@@ -308,7 +313,6 @@ class VlmPipeline(PaginatedPipeline):
     """
-
     def _turn_md_into_doc(self, conv_res):
         def _extract_markdown_code(text):
             """
             Extracts text from markdown code blocks (enclosed in triple backticks).
@@ -322,10 +326,7 @@ class VlmPipeline(PaginatedPipeline):
             """
             # Regex pattern to match content between triple backticks
             # This handles multiline content and optional language specifier
-            pattern = r'^```(?:\w*\n)?(.*?)```(\n)*$'
-
-            # Search for matches with DOTALL flag to match across multiple lines
-            matches = re.findall(pattern, text, re.DOTALL)
+            pattern = r"^```(?:\w*\n)?(.*?)```(\n)*$"

             # Search with DOTALL flag to match across multiple lines
             mtch = re.search(pattern, text, re.DOTALL)
@@ -338,7 +339,6 @@ class VlmPipeline(PaginatedPipeline):
             return text

-
         for pg_idx, page in enumerate(conv_res.pages):
             page_no = pg_idx + 1  # FIXME: might be incorrect

             predicted_text = ""
@@ -370,14 +370,18 @@ class VlmPipeline(PaginatedPipeline):
             conv_res.document.add_page(
                 page_no=page_no,
                 size=Size(width=pg_width, height=pg_height),
-                image=ImageRef.from_pil(image=page.image, dpi=72) if page.image else None,
+                image=ImageRef.from_pil(image=page.image, dpi=72)
+                if page.image
+                else None,
             )

             for item, level in page_doc.iterate_items():
                 item.prov = [
-                    ProvenanceItem(page_no=pg_idx+1,
-                                   bbox=BoundingBox(t=0.0, b=0.0, l=0.0, r=0.0),
-                                   charspan=[0,0])
+                    ProvenanceItem(
+                        page_no=pg_idx + 1,
+                        bbox=BoundingBox(t=0.0, b=0.0, l=0.0, r=0.0),
+                        charspan=[0, 0],
+                    )
                 ]
                 conv_res.document.append_child_item(child=item)
                 print(item)

View File

@@ -4,6 +4,7 @@ from pathlib import Path

 from docling_core.types.doc import DocItemLabel, ImageRefMode
 from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
+from tabulate import tabulate

 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_model_specializations import (
@@ -25,8 +26,6 @@ from docling.datamodel.pipeline_options import (
 from docling.document_converter import DocumentConverter, PdfFormatOption
 from docling.pipeline.vlm_pipeline import VlmPipeline

-from tabulate import tabulate
-
 ## Use experimental VlmPipeline
 pipeline_options = VlmPipelineOptions()
 # If force_backend_text = True, text from backend will be used instead of generated text
@@ -101,6 +100,7 @@ qwen_vlm_conversion_options = HuggingFaceVlmOptions(
 pipeline_options.vlm_options = qwen_vlm_conversion_options
 """

+
 def convert(sources: list[Path], converter):
     for source in sources:
         # start_time = time.time()
@@ -161,10 +161,16 @@ def convert(sources: list[Path], converter):
         print("====================================================")

         # return [source, f"{out_path / fname}.html", model_id, framework, inference_time, ]
-        return [source, model_id, framework, pg_num, inference_time, ]
+        return [
+            source,
+            model_id,
+            framework,
+            pg_num,
+            inference_time,
+        ]


 if __name__ == "__main__":
     sources = [
         # "tests/data/2305.03393v1-pg9-img.png",
         "tests/data/pdf/2305.03393v1-pg9.pdf",
@@ -187,7 +193,7 @@ if __name__ == "__main__":
     rows = []

     for vlm_options in [
         # smoldocling_vlm_conversion_options, \
-        smoldocling_vlm_mlx_conversion_options, \
+        smoldocling_vlm_mlx_conversion_options,
         # granite_vision_vlm_conversion_options, \
         # phi_vlm_conversion_options, \
         # qwen25_vl_3b_vlm_mlx_conversion_options, \