mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-16 16:48:21 +00:00
feat: Code and equation model for PDF and code blocks in markdown (#752)
* propagated changes for new CodeItem class Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * Rebased branch on latest main. changes for CodeItem Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * removed unused files Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * chore: update lockfile Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * pin latest docling-core Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * update docling-core pinning Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * pin docling-core Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use new add_code in backends and update typing in MD backend Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * added if statement for backend Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * removed unused import Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * removed print statements Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * gt for new pdf Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * Update docling/pipeline/standard_pdf_pipeline.py Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Signed-off-by: Matteo <43417658+Matteo-Omenetti@users.noreply.github.com> * fixed doc comment of __call__ function of code_formula_model Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * fix artifacts_path type Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move imports Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move expansion_factor to base class Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Matteo <43417658+Matteo-Omenetti@users.noreply.github.com> Co-authored-by: Christoph Auer <cau@zurich.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, Iterable, Optional
|
||||
|
||||
from docling_core.types.doc import DoclingDocument, NodeItem, TextItem
|
||||
from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
||||
@@ -53,6 +53,7 @@ class BaseItemAndImageEnrichmentModel(
|
||||
):
|
||||
|
||||
images_scale: float
|
||||
expansion_factor: float = 0.0
|
||||
|
||||
def prepare_element(
|
||||
self, conv_res: ConversionResult, element: NodeItem
|
||||
@@ -62,8 +63,22 @@ class BaseItemAndImageEnrichmentModel(
|
||||
|
||||
assert isinstance(element, TextItem)
|
||||
element_prov = element.prov[0]
|
||||
|
||||
bbox = element_prov.bbox
|
||||
width = bbox.r - bbox.l
|
||||
height = bbox.t - bbox.b
|
||||
|
||||
# TODO: move to a utility in the BoundingBox class
|
||||
expanded_bbox = BoundingBox(
|
||||
l=bbox.l - width * self.expansion_factor,
|
||||
t=bbox.t + height * self.expansion_factor,
|
||||
r=bbox.r + width * self.expansion_factor,
|
||||
b=bbox.b - height * self.expansion_factor,
|
||||
coord_origin=bbox.coord_origin,
|
||||
)
|
||||
|
||||
page_ix = element_prov.page_no - 1
|
||||
cropped_image = conv_res.pages[page_ix].get_image(
|
||||
scale=self.images_scale, cropbox=element_prov.bbox
|
||||
scale=self.images_scale, cropbox=expanded_bbox
|
||||
)
|
||||
return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
|
||||
|
||||
245
docling/models/code_formula_model.py
Normal file
245
docling/models/code_formula_model.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from docling_core.types.doc import (
|
||||
CodeItem,
|
||||
DocItemLabel,
|
||||
DoclingDocument,
|
||||
NodeItem,
|
||||
TextItem,
|
||||
)
|
||||
from docling_core.types.doc.labels import CodeLanguageLabel
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
|
||||
class CodeFormulaModelOptions(BaseModel):
|
||||
"""
|
||||
Configuration options for the CodeFormulaModel.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
kind : str
|
||||
Type of the model. Fixed value "code_formula".
|
||||
do_code_enrichment : bool
|
||||
True if code enrichment is enabled, False otherwise.
|
||||
do_formula_enrichment : bool
|
||||
True if formula enrichment is enabled, False otherwise.
|
||||
"""
|
||||
|
||||
kind: Literal["code_formula"] = "code_formula"
|
||||
do_code_enrichment: bool = True
|
||||
do_formula_enrichment: bool = True
|
||||
|
||||
|
||||
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
"""
|
||||
Model for processing and enriching documents with code and formula predictions.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
enabled : bool
|
||||
True if the model is enabled, False otherwise.
|
||||
options : CodeFormulaModelOptions
|
||||
Configuration options for the CodeFormulaModel.
|
||||
code_formula_model : CodeFormulaPredictor
|
||||
The predictor model for code and formula processing.
|
||||
|
||||
Methods
|
||||
-------
|
||||
__init__(self, enabled, artifacts_path, accelerator_options, code_formula_options)
|
||||
Initializes the CodeFormulaModel with the given configuration options.
|
||||
is_processable(self, doc, element)
|
||||
Determines if a given element in a document can be processed by the model.
|
||||
__call__(self, doc, element_batch)
|
||||
Processes the given batch of elements and enriches them with predictions.
|
||||
"""
|
||||
|
||||
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
||||
expansion_factor = 0.03
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
options: CodeFormulaModelOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
"""
|
||||
Initializes the CodeFormulaModel with the given configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
enabled : bool
|
||||
True if the model is enabled, False otherwise.
|
||||
artifacts_path : Path
|
||||
Path to the directory containing the model artifacts.
|
||||
options : CodeFormulaModelOptions
|
||||
Configuration options for the model.
|
||||
accelerator_options : AcceleratorOptions
|
||||
Options specifying the device and number of threads for acceleration.
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.options = options
|
||||
|
||||
if self.enabled:
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
from docling_ibm_models.code_formula_model.code_formula_predictor import (
|
||||
CodeFormulaPredictor,
|
||||
)
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models_hf()
|
||||
else:
|
||||
artifacts_path = Path(artifacts_path)
|
||||
|
||||
self.code_formula_model = CodeFormulaPredictor(
|
||||
artifacts_path=artifacts_path,
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/CodeFormula",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v1.0.0",
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||
"""
|
||||
Determines if a given element in a document can be processed by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
doc : DoclingDocument
|
||||
The document being processed.
|
||||
element : NodeItem
|
||||
The element within the document to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the element can be processed, False otherwise.
|
||||
"""
|
||||
return self.enabled and (
|
||||
(isinstance(element, CodeItem) and self.options.do_code_enrichment)
|
||||
or (
|
||||
isinstance(element, TextItem)
|
||||
and element.label == DocItemLabel.FORMULA
|
||||
and self.options.do_formula_enrichment
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
|
||||
"""Extracts a programming language from the beginning of a string.
|
||||
|
||||
This function checks if the input string starts with a pattern of the form
|
||||
``<_some_language_>``. If it does, it extracts the language string and returns
|
||||
a tuple of (remainder, language). Otherwise, it returns the original string
|
||||
and `None`.
|
||||
|
||||
Args:
|
||||
input_string (str): The input string, which may start with ``<_language_>``.
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[str]]:
|
||||
A tuple where:
|
||||
- The first element is either:
|
||||
- The remainder of the string (everything after ``<_language_>``),
|
||||
if a match is found; or
|
||||
- The original string, if no match is found.
|
||||
- The second element is the extracted language if a match is found;
|
||||
otherwise, `None`.
|
||||
"""
|
||||
pattern = r"^<_([^>]+)_>\s*(.*)"
|
||||
match = re.match(pattern, input_string, flags=re.DOTALL)
|
||||
if match:
|
||||
language = str(match.group(1)) # the captured programming language
|
||||
remainder = str(match.group(2)) # everything after the <_language_>
|
||||
return remainder, language
|
||||
else:
|
||||
return input_string, None
|
||||
|
||||
def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
|
||||
"""
|
||||
Converts a string to a corresponding `CodeLanguageLabel` enum member.
|
||||
|
||||
If the provided string does not match any value in `CodeLanguageLabel`,
|
||||
it defaults to `CodeLanguageLabel.UNKNOWN`.
|
||||
|
||||
Args:
|
||||
value (Optional[str]): The string representation of the code language or None.
|
||||
|
||||
Returns:
|
||||
CodeLanguageLabel: The corresponding enum member if the value is valid,
|
||||
otherwise `CodeLanguageLabel.UNKNOWN`.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return CodeLanguageLabel.UNKNOWN
|
||||
|
||||
try:
|
||||
return CodeLanguageLabel(value)
|
||||
except ValueError:
|
||||
return CodeLanguageLabel.UNKNOWN
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
doc: DoclingDocument,
|
||||
element_batch: Iterable[ItemAndImageEnrichmentElement],
|
||||
) -> Iterable[NodeItem]:
|
||||
"""
|
||||
Processes the given batch of elements and enriches them with predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
doc : DoclingDocument
|
||||
The document being processed.
|
||||
element_batch : Iterable[ItemAndImageEnrichmentElement]
|
||||
A batch of elements to be processed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable[Any]
|
||||
An iterable of enriched elements.
|
||||
"""
|
||||
if not self.enabled:
|
||||
for element in element_batch:
|
||||
yield element.item
|
||||
return
|
||||
|
||||
labels: List[str] = []
|
||||
images: List[Image.Image] = []
|
||||
elements: List[TextItem] = []
|
||||
for el in element_batch:
|
||||
assert isinstance(el.item, TextItem)
|
||||
elements.append(el.item)
|
||||
labels.append(el.item.label)
|
||||
images.append(el.image)
|
||||
|
||||
outputs = self.code_formula_model.predict(images, labels)
|
||||
|
||||
for item, output in zip(elements, outputs):
|
||||
if isinstance(item, CodeItem):
|
||||
output, code_language = self._extract_code_language(output)
|
||||
item.code_language = self._get_code_language_enum(code_language)
|
||||
item.text = output
|
||||
|
||||
yield item
|
||||
@@ -40,7 +40,7 @@ class LayoutModel(BasePageModel):
|
||||
DocItemLabel.PAGE_FOOTER,
|
||||
DocItemLabel.CODE,
|
||||
DocItemLabel.LIST_ITEM,
|
||||
# "Formula",
|
||||
DocItemLabel.FORMULA,
|
||||
]
|
||||
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
||||
|
||||
|
||||
@@ -135,31 +135,6 @@ class PageAssembleModel(BasePageModel):
|
||||
)
|
||||
elements.append(fig)
|
||||
body.append(fig)
|
||||
elif cluster.label == LayoutModel.FORMULA_LABEL:
|
||||
equation = None
|
||||
if page.predictions.equations_prediction:
|
||||
equation = page.predictions.equations_prediction.equation_map.get(
|
||||
cluster.id, None
|
||||
)
|
||||
if (
|
||||
not equation
|
||||
): # fallback: add empty formula, if it isn't present
|
||||
text = self.sanitize_text(
|
||||
[
|
||||
cell.text.replace("\x02", "-").strip()
|
||||
for cell in cluster.cells
|
||||
if len(cell.text.strip()) > 0
|
||||
]
|
||||
)
|
||||
equation = TextElement(
|
||||
label=cluster.label,
|
||||
id=cluster.id,
|
||||
cluster=cluster,
|
||||
page_no=page.page_no,
|
||||
text=text,
|
||||
)
|
||||
elements.append(equation)
|
||||
body.append(equation)
|
||||
elif cluster.label in LayoutModel.CONTAINER_LABELS:
|
||||
container_el = ContainerElement(
|
||||
label=cluster.label,
|
||||
|
||||
@@ -20,7 +20,6 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TesseractOcrCliModel(BaseOcrModel):
|
||||
|
||||
def __init__(self, enabled: bool, options: TesseractCliOcrOptions):
|
||||
super().__init__(enabled=enabled, options=options)
|
||||
self.options: TesseractCliOcrOptions
|
||||
|
||||
Reference in New Issue
Block a user