mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-24 19:14:23 +00:00
fix: use only backend for picture classifier (#1904)
use backend for picture classifier Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
dd8fde7f19
commit
edd4356aac
@ -14,7 +14,8 @@ from PIL import Image
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||||
from docling.models.base_model import BaseEnrichmentModel
|
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
||||||
|
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
||||||
from docling.models.utils.hf_model_download import download_hf_model
|
from docling.models.utils.hf_model_download import download_hf_model
|
||||||
from docling.utils.accelerator_utils import decide_device
|
from docling.utils.accelerator_utils import decide_device
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ class DocumentPictureClassifierOptions(BaseModel):
|
|||||||
kind: Literal["document_picture_classifier"] = "document_picture_classifier"
|
kind: Literal["document_picture_classifier"] = "document_picture_classifier"
|
||||||
|
|
||||||
|
|
||||||
class DocumentPictureClassifier(BaseEnrichmentModel):
|
class DocumentPictureClassifier(BaseItemAndImageEnrichmentModel):
|
||||||
"""
|
"""
|
||||||
A model for classifying pictures in documents.
|
A model for classifying pictures in documents.
|
||||||
|
|
||||||
@ -135,7 +136,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
doc: DoclingDocument,
|
doc: DoclingDocument,
|
||||||
element_batch: Iterable[NodeItem],
|
element_batch: Iterable[ItemAndImageEnrichmentElement],
|
||||||
) -> Iterable[NodeItem]:
|
) -> Iterable[NodeItem]:
|
||||||
"""
|
"""
|
||||||
Processes a batch of elements and enriches them with classification predictions.
|
Processes a batch of elements and enriches them with classification predictions.
|
||||||
@ -144,7 +145,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
----------
|
----------
|
||||||
doc : DoclingDocument
|
doc : DoclingDocument
|
||||||
The document containing the elements to be processed.
|
The document containing the elements to be processed.
|
||||||
element_batch : Iterable[NodeItem]
|
element_batch : Iterable[ItemAndImageEnrichmentElement]
|
||||||
A batch of pictures to classify.
|
A batch of pictures to classify.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@ -155,22 +156,20 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
"""
|
"""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
for element in element_batch:
|
for element in element_batch:
|
||||||
yield element
|
yield element.item
|
||||||
return
|
return
|
||||||
|
|
||||||
images: List[Union[Image.Image, np.ndarray]] = []
|
images: List[Union[Image.Image, np.ndarray]] = []
|
||||||
elements: List[PictureItem] = []
|
elements: List[PictureItem] = []
|
||||||
for el in element_batch:
|
for el in element_batch:
|
||||||
assert isinstance(el, PictureItem)
|
assert isinstance(el.item, PictureItem)
|
||||||
elements.append(el)
|
elements.append(el.item)
|
||||||
img = el.get_image(doc)
|
images.append(el.image)
|
||||||
assert img is not None
|
|
||||||
images.append(img)
|
|
||||||
|
|
||||||
outputs = self.document_picture_classifier.predict(images)
|
outputs = self.document_picture_classifier.predict(images)
|
||||||
|
|
||||||
for element, output in zip(elements, outputs):
|
for item, output in zip(elements, outputs):
|
||||||
element.annotations.append(
|
item.annotations.append(
|
||||||
PictureClassificationData(
|
PictureClassificationData(
|
||||||
provenance="DocumentPictureClassifier",
|
provenance="DocumentPictureClassifier",
|
||||||
predicted_classes=[
|
predicted_classes=[
|
||||||
@ -183,4 +182,4 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield element
|
yield item
|
||||||
|
@ -129,6 +129,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
if (
|
if (
|
||||||
self.pipeline_options.do_formula_enrichment
|
self.pipeline_options.do_formula_enrichment
|
||||||
or self.pipeline_options.do_code_enrichment
|
or self.pipeline_options.do_code_enrichment
|
||||||
|
or self.pipeline_options.do_picture_classification
|
||||||
or self.pipeline_options.do_picture_description
|
or self.pipeline_options.do_picture_description
|
||||||
):
|
):
|
||||||
self.keep_backend = True
|
self.keep_backend = True
|
||||||
|
@ -17,8 +17,9 @@ def get_converter():
|
|||||||
pipeline_options.do_table_structure = False
|
pipeline_options.do_table_structure = False
|
||||||
pipeline_options.do_code_enrichment = False
|
pipeline_options.do_code_enrichment = False
|
||||||
pipeline_options.do_formula_enrichment = False
|
pipeline_options.do_formula_enrichment = False
|
||||||
|
pipeline_options.generate_picture_images = False
|
||||||
|
pipeline_options.generate_page_images = False
|
||||||
pipeline_options.do_picture_classification = True
|
pipeline_options.do_picture_classification = True
|
||||||
pipeline_options.generate_picture_images = True
|
|
||||||
pipeline_options.images_scale = 2
|
pipeline_options.images_scale = 2
|
||||||
|
|
||||||
converter = DocumentConverter(
|
converter = DocumentConverter(
|
||||||
|
Loading…
Reference in New Issue
Block a user