fix: use only backend for picture classifier (#1904)

use backend for picture classifier

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-07-07 16:23:16 +02:00 committed by GitHub
parent dd8fde7f19
commit edd4356aac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 14 deletions

View File

@ -14,7 +14,8 @@ from PIL import Image
from pydantic import BaseModel from pydantic import BaseModel
from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.models.base_model import BaseEnrichmentModel from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.models.utils.hf_model_download import download_hf_model from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device from docling.utils.accelerator_utils import decide_device
@ -32,7 +33,7 @@ class DocumentPictureClassifierOptions(BaseModel):
kind: Literal["document_picture_classifier"] = "document_picture_classifier" kind: Literal["document_picture_classifier"] = "document_picture_classifier"
class DocumentPictureClassifier(BaseEnrichmentModel): class DocumentPictureClassifier(BaseItemAndImageEnrichmentModel):
""" """
A model for classifying pictures in documents. A model for classifying pictures in documents.
@ -135,7 +136,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
def __call__( def __call__(
self, self,
doc: DoclingDocument, doc: DoclingDocument,
element_batch: Iterable[NodeItem], element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]: ) -> Iterable[NodeItem]:
""" """
Processes a batch of elements and enriches them with classification predictions. Processes a batch of elements and enriches them with classification predictions.
@ -144,7 +145,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
---------- ----------
doc : DoclingDocument doc : DoclingDocument
The document containing the elements to be processed. The document containing the elements to be processed.
element_batch : Iterable[NodeItem] element_batch : Iterable[ItemAndImageEnrichmentElement]
A batch of pictures to classify. A batch of pictures to classify.
Returns Returns
@ -155,22 +156,20 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
""" """
if not self.enabled: if not self.enabled:
for element in element_batch: for element in element_batch:
yield element yield element.item
return return
images: List[Union[Image.Image, np.ndarray]] = [] images: List[Union[Image.Image, np.ndarray]] = []
elements: List[PictureItem] = [] elements: List[PictureItem] = []
for el in element_batch: for el in element_batch:
assert isinstance(el, PictureItem) assert isinstance(el.item, PictureItem)
elements.append(el) elements.append(el.item)
img = el.get_image(doc) images.append(el.image)
assert img is not None
images.append(img)
outputs = self.document_picture_classifier.predict(images) outputs = self.document_picture_classifier.predict(images)
for element, output in zip(elements, outputs): for item, output in zip(elements, outputs):
element.annotations.append( item.annotations.append(
PictureClassificationData( PictureClassificationData(
provenance="DocumentPictureClassifier", provenance="DocumentPictureClassifier",
predicted_classes=[ predicted_classes=[
@ -183,4 +182,4 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
) )
) )
yield element yield item

View File

@ -129,6 +129,7 @@ class StandardPdfPipeline(PaginatedPipeline):
if ( if (
self.pipeline_options.do_formula_enrichment self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment or self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_picture_classification
or self.pipeline_options.do_picture_description or self.pipeline_options.do_picture_description
): ):
self.keep_backend = True self.keep_backend = True

View File

@ -17,8 +17,9 @@ def get_converter():
pipeline_options.do_table_structure = False pipeline_options.do_table_structure = False
pipeline_options.do_code_enrichment = False pipeline_options.do_code_enrichment = False
pipeline_options.do_formula_enrichment = False pipeline_options.do_formula_enrichment = False
pipeline_options.generate_picture_images = False
pipeline_options.generate_page_images = False
pipeline_options.do_picture_classification = True pipeline_options.do_picture_classification = True
pipeline_options.generate_picture_images = True
pipeline_options.images_scale = 2 pipeline_options.images_scale = 2
converter = DocumentConverter( converter = DocumentConverter(