Mirror of https://github.com/DS4SD/docling.git
feat: Code and equation model for PDF and code blocks in markdown (#752)
* propagated changes for new CodeItem class
* Rebased branch on latest main. changes for CodeItem
* removed unused files
* chore: update lockfile
* pin latest docling-core
* update docling-core pinning
* pin docling-core
* use new add_code in backends and update typing in MD backend
* added if statement for backend
* removed unused import
* removed print statements
* gt for new pdf
* Update docling/pipeline/standard_pdf_pipeline.py
* fixed doc comment of __call__ function of code_formula_model
* fix artifacts_path type
* move imports
* move expansion_factor to base class

---------

Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Matteo <43417658+Matteo-Omenetti@users.noreply.github.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
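Usage sketch (not part of the diff): the new enrichment is switched on through the two PdfPipelineOptions flags introduced below, do_code_enrichment and do_formula_enrichment. The converter wiring (DocumentConverter, PdfFormatOption, InputFormat) is docling's standard API rather than something this commit changes, and the input path is made up:

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption

    # Enable the CodeFormulaModel enrichment added by this commit.
    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_code_enrichment = True     # perform code OCR on CodeItem elements
    pipeline_options.do_formula_enrichment = True  # perform formula OCR, returning LaTeX

    converter = DocumentConverter(
        format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
    )
    result = converter.convert("example.pdf")  # made-up input path
    print(result.document.export_to_markdown())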
@@ -24,7 +24,6 @@ _log = logging.getLogger(__name__)
 class AsciiDocBackend(DeclarativeDocumentBackend):
     def __init__(self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]):
         super().__init__(in_doc, path_or_stream)
@@ -215,7 +215,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
         label = DocItemLabel.CODE
         if len(text) == 0:
             return
-        doc.add_text(parent=self.parents[self.level], label=label, text=text)
+        doc.add_code(parent=self.parents[self.level], label=label, text=text)

     def handle_paragraph(self, element, idx, doc):
         """Handles paragraph tags (p)."""
@@ -3,19 +3,22 @@ import re
 import warnings
 from io import BytesIO
 from pathlib import Path
-from typing import Set, Union
+from typing import List, Optional, Set, Union

 import marko
 import marko.ext
 import marko.ext.gfm
 import marko.inline
 from docling_core.types.doc import (
+    DocItem,
     DocItemLabel,
     DoclingDocument,
     DocumentOrigin,
     GroupLabel,
+    NodeItem,
     TableCell,
     TableData,
+    TextItem,
 )
 from marko import Markdown
@@ -27,8 +30,7 @@ _log = logging.getLogger(__name__)
 class MarkdownDocumentBackend(DeclarativeDocumentBackend):

-    def shorten_underscore_sequences(self, markdown_text, max_length=10):
+    def shorten_underscore_sequences(self, markdown_text: str, max_length: int = 10):
         # This regex will match any sequence of underscores
         pattern = r"_+"
@@ -90,13 +92,13 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
             ) from e
         return

-    def close_table(self, doc=None):
+    def close_table(self, doc: DoclingDocument):
         if self.in_table:
             _log.debug("=== TABLE START ===")
             for md_table_row in self.md_table_buffer:
                 _log.debug(md_table_row)
             _log.debug("=== TABLE END ===")
-            tcells = []
+            tcells: List[TableCell] = []
             result_table = []
             for n, md_table_row in enumerate(self.md_table_buffer):
                 data = []
@@ -137,15 +139,19 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
             self.in_table = False
             self.md_table_buffer = []  # clean table markdown buffer
             # Initialize Docling TableData
-            data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=tcells)
+            table_data = TableData(
+                num_rows=num_rows, num_cols=num_cols, table_cells=tcells
+            )
             # Populate
             for tcell in tcells:
-                data.table_cells.append(tcell)
+                table_data.table_cells.append(tcell)
             if len(tcells) > 0:
-                doc.add_table(data=data)
+                doc.add_table(data=table_data)
         return

-    def process_inline_text(self, parent_element, doc=None):
+    def process_inline_text(
+        self, parent_element: Optional[NodeItem], doc: DoclingDocument
+    ):
         # self.inline_text_buffer += str(text_in)
         txt = self.inline_text_buffer.strip()
         if len(txt) > 0:
@@ -156,14 +162,20 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
             )
         self.inline_text_buffer = ""

-    def iterate_elements(self, element, depth=0, doc=None, parent_element=None):
+    def iterate_elements(
+        self,
+        element: marko.block.Element,
+        depth: int,
+        doc: DoclingDocument,
+        parent_element: Optional[NodeItem] = None,
+    ):
         # Iterates over all elements in the AST
         # Check for different element types and process relevant details
         if isinstance(element, marko.block.Heading):
             self.close_table(doc)
             self.process_inline_text(parent_element, doc)
             _log.debug(
-                f" - Heading level {element.level}, content: {element.children[0].children}"
+                f" - Heading level {element.level}, content: {element.children[0].children}"  # type: ignore
             )
             if element.level == 1:
                 doc_label = DocItemLabel.TITLE
@@ -172,10 +184,10 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):

             # Header could have arbitrary inclusion of bold, italic or emphasis,
             # hence we need to traverse the tree to get full text of a header
-            strings = []
+            strings: List[str] = []

             # Define a recursive function to traverse the tree
-            def traverse(node):
+            def traverse(node: marko.block.BlockElement):
                 # Check if the node has a "children" attribute
                 if hasattr(node, "children"):
                     # If "children" is a list, continue traversal
@@ -209,9 +221,13 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
             self.process_inline_text(parent_element, doc)
             _log.debug(" - List item")

-            snippet_text = str(element.children[0].children[0].children)
+            snippet_text = str(element.children[0].children[0].children)  # type: ignore
             is_numbered = False
-            if parent_element.label == GroupLabel.ORDERED_LIST:
+            if (
+                parent_element is not None
+                and isinstance(parent_element, DocItem)
+                and parent_element.label == GroupLabel.ORDERED_LIST
+            ):
                 is_numbered = True
             doc.add_list_item(
                 enumerated=is_numbered, parent=parent_element, text=snippet_text
@@ -221,7 +237,14 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
             self.close_table(doc)
             self.process_inline_text(parent_element, doc)
             _log.debug(f" - Image with alt: {element.title}, url: {element.dest}")
-            doc.add_picture(parent=parent_element, caption=element.title)
+
+            fig_caption: Optional[TextItem] = None
+            if element.title is not None and element.title != "":
+                fig_caption = doc.add_text(
+                    label=DocItemLabel.CAPTION, text=element.title
+                )
+
+            doc.add_picture(parent=parent_element, caption=fig_caption)

         elif isinstance(element, marko.block.Paragraph):
             self.process_inline_text(parent_element, doc)
@@ -252,27 +275,21 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
             self.process_inline_text(parent_element, doc)
             _log.debug(f" - Code Span: {element.children}")
             snippet_text = str(element.children).strip()
-            doc.add_text(
-                label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
-            )
+            doc.add_code(parent=parent_element, text=snippet_text)

         elif isinstance(element, marko.block.CodeBlock):
             self.close_table(doc)
             self.process_inline_text(parent_element, doc)
             _log.debug(f" - Code Block: {element.children}")
-            snippet_text = str(element.children[0].children).strip()
-            doc.add_text(
-                label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
-            )
+            snippet_text = str(element.children[0].children).strip()  # type: ignore
+            doc.add_code(parent=parent_element, text=snippet_text)

         elif isinstance(element, marko.block.FencedCode):
             self.close_table(doc)
             self.process_inline_text(parent_element, doc)
             _log.debug(f" - Code Block: {element.children}")
-            snippet_text = str(element.children[0].children).strip()
-            doc.add_text(
-                label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
-            )
+            snippet_text = str(element.children[0].children).strip()  # type: ignore
+            doc.add_code(parent=parent_element, text=snippet_text)

         elif isinstance(element, marko.inline.LineBreak):
             self.process_inline_text(parent_element, doc)
@@ -44,7 +44,6 @@ class ExcelTable(BaseModel):
 class MsExcelDocumentBackend(DeclarativeDocumentBackend):
     def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
         super().__init__(in_doc, path_or_stream)
@@ -26,7 +26,6 @@ _log = logging.getLogger(__name__)
 class MsWordDocumentBackend(DeclarativeDocumentBackend):
     def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
         super().__init__(in_doc, path_or_stream)
         self.XML_KEY = (
@@ -12,7 +12,6 @@ from docling.datamodel.document import InputDocument
 class PdfPageBackend(ABC):

     @abstractmethod
     def get_text_in_rect(self, bbox: BoundingBox) -> str:
         pass
@@ -45,7 +44,6 @@ class PdfPageBackend(ABC):
 class PdfDocumentBackend(PaginatedDocumentBackend):

     def __init__(self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]):
         super().__init__(in_doc, path_or_stream)
@@ -1,17 +1,11 @@
 import logging
 import os
 import warnings
 from enum import Enum
 from pathlib import Path
-from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import Any, List, Literal, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from pydantic_settings import (
-    BaseSettings,
-    PydanticBaseSettingsSource,
-    SettingsConfigDict,
-)
-from typing_extensions import deprecated
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic_settings import BaseSettings, SettingsConfigDict

 _log = logging.getLogger(__name__)
@@ -225,6 +219,8 @@ class PdfPipelineOptions(PipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
+    do_code_enrichment: bool = False  # True: perform code OCR
+    do_formula_enrichment: bool = False  # True: perform formula OCR, return Latex code

     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Generic, Iterable, Optional

-from docling_core.types.doc import DoclingDocument, NodeItem, TextItem
+from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem
 from typing_extensions import TypeVar

 from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
@@ -53,6 +53,7 @@ class BaseItemAndImageEnrichmentModel(
 ):

     images_scale: float
+    expansion_factor: float = 0.0

     def prepare_element(
         self, conv_res: ConversionResult, element: NodeItem
@@ -62,8 +63,22 @@ class BaseItemAndImageEnrichmentModel(
         assert isinstance(element, TextItem)
         element_prov = element.prov[0]

+        bbox = element_prov.bbox
+        width = bbox.r - bbox.l
+        height = bbox.t - bbox.b
+
+        # TODO: move to a utility in the BoundingBox class
+        expanded_bbox = BoundingBox(
+            l=bbox.l - width * self.expansion_factor,
+            t=bbox.t + height * self.expansion_factor,
+            r=bbox.r + width * self.expansion_factor,
+            b=bbox.b - height * self.expansion_factor,
+            coord_origin=bbox.coord_origin,
+        )
+
         page_ix = element_prov.page_no - 1
         cropped_image = conv_res.pages[page_ix].get_image(
-            scale=self.images_scale, cropbox=element_prov.bbox
+            scale=self.images_scale, cropbox=expanded_bbox
         )
         return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
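Aside (not part of the diff): with the expansion_factor of 0.03 that CodeFormulaModel sets below, the crop box grows by 3% of its width on the left and right and 3% of its height on the top and bottom before the image is taken. A standalone sketch of that arithmetic with made-up coordinates (plain floats stand in for docling's BoundingBox, bottom-left origin so t > b):

    # Made-up bbox: 200 wide, 50 tall, expanded with expansion_factor = 0.03.
    expansion_factor = 0.03
    l, t, r, b = 100.0, 500.0, 300.0, 450.0

    width = r - l    # 200.0
    height = t - b   # 50.0
    expanded = (
        l - width * expansion_factor,   # 94.0
        t + height * expansion_factor,  # 501.5
        r + width * expansion_factor,   # 306.0
        b - height * expansion_factor,  # 448.5
    )
    print(expanded)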
docling/models/code_formula_model.py (new file, 245 lines)
@@ -0,0 +1,245 @@
import re
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple, Union

from docling_core.types.doc import (
    CodeItem,
    DocItemLabel,
    DoclingDocument,
    NodeItem,
    TextItem,
)
from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image
from pydantic import BaseModel

from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.utils.accelerator_utils import decide_device


class CodeFormulaModelOptions(BaseModel):
    """
    Configuration options for the CodeFormulaModel.

    Attributes
    ----------
    kind : str
        Type of the model. Fixed value "code_formula".
    do_code_enrichment : bool
        True if code enrichment is enabled, False otherwise.
    do_formula_enrichment : bool
        True if formula enrichment is enabled, False otherwise.
    """

    kind: Literal["code_formula"] = "code_formula"
    do_code_enrichment: bool = True
    do_formula_enrichment: bool = True

class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
    """
    Model for processing and enriching documents with code and formula predictions.

    Attributes
    ----------
    enabled : bool
        True if the model is enabled, False otherwise.
    options : CodeFormulaModelOptions
        Configuration options for the CodeFormulaModel.
    code_formula_model : CodeFormulaPredictor
        The predictor model for code and formula processing.

    Methods
    -------
    __init__(self, enabled, artifacts_path, accelerator_options, code_formula_options)
        Initializes the CodeFormulaModel with the given configuration options.
    is_processable(self, doc, element)
        Determines if a given element in a document can be processed by the model.
    __call__(self, doc, element_batch)
        Processes the given batch of elements and enriches them with predictions.
    """

    images_scale = 1.66  # = 120 dpi, aligned with training data resolution
    expansion_factor = 0.03

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: CodeFormulaModelOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the CodeFormulaModel with the given configuration.

        Parameters
        ----------
        enabled : bool
            True if the model is enabled, False otherwise.
        artifacts_path : Path
            Path to the directory containing the model artifacts.
        options : CodeFormulaModelOptions
            Configuration options for the model.
        accelerator_options : AcceleratorOptions
            Options specifying the device and number of threads for acceleration.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)

            from docling_ibm_models.code_formula_model.code_formula_predictor import (
                CodeFormulaPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models_hf()
            else:
                artifacts_path = Path(artifacts_path)

            self.code_formula_model = CodeFormulaPredictor(
                artifacts_path=artifacts_path,
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    @staticmethod
    def download_models_hf(
        local_dir: Optional[Path] = None, force: bool = False
    ) -> Path:
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/CodeFormula",
            force_download=force,
            local_dir=local_dir,
            revision="v1.0.0",
        )

        return Path(download_path)

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if a given element in a document can be processed by the model.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element : NodeItem
            The element within the document to check.

        Returns
        -------
        bool
            True if the element can be processed, False otherwise.
        """
        return self.enabled and (
            (isinstance(element, CodeItem) and self.options.do_code_enrichment)
            or (
                isinstance(element, TextItem)
                and element.label == DocItemLabel.FORMULA
                and self.options.do_formula_enrichment
            )
        )

    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
        """Extracts a programming language from the beginning of a string.

        This function checks if the input string starts with a pattern of the form
        ``<_some_language_>``. If it does, it extracts the language string and returns
        a tuple of (remainder, language). Otherwise, it returns the original string
        and `None`.

        Args:
            input_string (str): The input string, which may start with ``<_language_>``.

        Returns:
            Tuple[str, Optional[str]]:
                A tuple where:
                - The first element is either:
                    - The remainder of the string (everything after ``<_language_>``),
                      if a match is found; or
                    - The original string, if no match is found.
                - The second element is the extracted language if a match is found;
                  otherwise, `None`.
        """
        pattern = r"^<_([^>]+)_>\s*(.*)"
        match = re.match(pattern, input_string, flags=re.DOTALL)
        if match:
            language = str(match.group(1))  # the captured programming language
            remainder = str(match.group(2))  # everything after the <_language_>
            return remainder, language
        else:
            return input_string, None

    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
        """
        Converts a string to a corresponding `CodeLanguageLabel` enum member.

        If the provided string does not match any value in `CodeLanguageLabel`,
        it defaults to `CodeLanguageLabel.UNKNOWN`.

        Args:
            value (Optional[str]): The string representation of the code language or None.

        Returns:
            CodeLanguageLabel: The corresponding enum member if the value is valid,
            otherwise `CodeLanguageLabel.UNKNOWN`.
        """
        if not isinstance(value, str):
            return CodeLanguageLabel.UNKNOWN

        try:
            return CodeLanguageLabel(value)
        except ValueError:
            return CodeLanguageLabel.UNKNOWN

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """
        Processes the given batch of elements and enriches them with predictions.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element_batch : Iterable[ItemAndImageEnrichmentElement]
            A batch of elements to be processed.

        Returns
        -------
        Iterable[Any]
            An iterable of enriched elements.
        """
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        labels: List[str] = []
        images: List[Image.Image] = []
        elements: List[TextItem] = []
        for el in element_batch:
            assert isinstance(el.item, TextItem)
            elements.append(el.item)
            labels.append(el.item.label)
            images.append(el.image)

        outputs = self.code_formula_model.predict(images, labels)

        for item, output in zip(elements, outputs):
            if isinstance(item, CodeItem):
                output, code_language = self._extract_code_language(output)
                item.code_language = self._get_code_language_enum(code_language)
            item.text = output

            yield item
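Aside (not part of the diff): _extract_code_language above relies on the predictor prefixing its code output with a ``<_language_>`` tag, which is then stripped and mapped to CodeLanguageLabel via _get_code_language_enum. A standalone sketch of that parsing on a made-up string:

    import re

    pattern = r"^<_([^>]+)_>\s*(.*)"  # same pattern as _extract_code_language
    sample = "<_Python_> def add(a, b):\n    return a + b"  # made-up predictor output
    match = re.match(pattern, sample, flags=re.DOTALL)
    if match:
        print(match.group(1))  # "Python", the extracted language
        print(match.group(2))  # the code itself, with the tag stripped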
@@ -40,7 +40,7 @@ class LayoutModel(BasePageModel):
         DocItemLabel.PAGE_FOOTER,
         DocItemLabel.CODE,
         DocItemLabel.LIST_ITEM,
-        # "Formula",
+        DocItemLabel.FORMULA,
     ]
     PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
@@ -135,31 +135,6 @@ class PageAssembleModel(BasePageModel):
                             )
                             elements.append(fig)
                             body.append(fig)
-                        elif cluster.label == LayoutModel.FORMULA_LABEL:
-                            equation = None
-                            if page.predictions.equations_prediction:
-                                equation = page.predictions.equations_prediction.equation_map.get(
-                                    cluster.id, None
-                                )
-                            if (
-                                not equation
-                            ):  # fallback: add empty formula, if it isn't present
-                                text = self.sanitize_text(
-                                    [
-                                        cell.text.replace("\x02", "-").strip()
-                                        for cell in cluster.cells
-                                        if len(cell.text.strip()) > 0
-                                    ]
-                                )
-                                equation = TextElement(
-                                    label=cluster.label,
-                                    id=cluster.id,
-                                    cluster=cluster,
-                                    page_no=page.page_no,
-                                    text=text,
-                                )
-                            elements.append(equation)
-                            body.append(equation)
                         elif cluster.label in LayoutModel.CONTAINER_LABELS:
                             container_el = ContainerElement(
                                 label=cluster.label,
@@ -20,7 +20,6 @@ _log = logging.getLogger(__name__)
 class TesseractOcrCliModel(BaseOcrModel):

     def __init__(self, enabled: bool, options: TesseractCliOcrOptions):
         super().__init__(enabled=enabled, options=options)
         self.options: TesseractCliOcrOptions
@@ -3,7 +3,7 @@ import logging
 import time
 import traceback
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Any, Callable, Iterable, List

 from docling_core.types.doc import DoclingDocument, NodeItem
@@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
 from docling.datamodel.document import ConversionResult, InputDocument
 from docling.datamodel.pipeline_options import PipelineOptions
 from docling.datamodel.settings import settings
-from docling.models.base_model import BaseEnrichmentModel
+from docling.models.base_model import GenericEnrichmentModel
 from docling.utils.profiling import ProfilingScope, TimeRecorder
 from docling.utils.utils import chunkify
@@ -30,7 +30,7 @@ class BasePipeline(ABC):
         self.pipeline_options = pipeline_options
         self.keep_images = False
         self.build_pipe: List[Callable] = []
-        self.enrichment_pipe: List[BaseEnrichmentModel] = []
+        self.enrichment_pipe: List[GenericEnrichmentModel[Any]] = []

     def execute(self, in_doc: InputDocument, raises_on_error: bool) -> ConversionResult:
         conv_res = ConversionResult(input=in_doc)
@@ -66,7 +66,7 @@ class BasePipeline(ABC):
     def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:

         def _prepare_elements(
-            conv_res: ConversionResult, model: BaseEnrichmentModel
+            conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
         ) -> Iterable[NodeItem]:
             for doc_element, _level in conv_res.document.iterate_items():
                 prepared_element = model.prepare_element(
@@ -1,7 +1,7 @@
 import logging
 import sys
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import Optional

 from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
@@ -17,8 +17,8 @@ from docling.datamodel.pipeline_options import (
     TesseractCliOcrOptions,
     TesseractOcrOptions,
 )
-from docling.models.base_model import BasePageModel
 from docling.models.base_ocr_model import BaseOcrModel
+from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
 from docling.models.ds_glm_model import GlmModel, GlmOptions
 from docling.models.easyocr_model import EasyOcrModel
 from docling.models.layout_model import LayoutModel
@@ -93,8 +93,25 @@ class StandardPdfPipeline(PaginatedPipeline):

         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
+            # Code Formula Enrichment Model
+            CodeFormulaModel(
+                enabled=pipeline_options.do_code_enrichment
+                or pipeline_options.do_formula_enrichment,
+                artifacts_path=pipeline_options.artifacts_path,
+                options=CodeFormulaModelOptions(
+                    do_code_enrichment=pipeline_options.do_code_enrichment,
+                    do_formula_enrichment=pipeline_options.do_formula_enrichment,
+                ),
+                accelerator_options=pipeline_options.accelerator_options,
+            ),
         ]

+        if (
+            self.pipeline_options.do_formula_enrichment
+            or self.pipeline_options.do_code_enrichment
+        ):
+            self.keep_backend = True
+
     @staticmethod
     def download_models_hf(
         local_dir: Optional[Path] = None, force: bool = False
@@ -270,7 +270,6 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
             container_el = doc.add_group(label=group_label)

             _add_child_elements(container_el, doc, obj, pelem)

         elif "text" in obj:
             text = obj["text"][span_i:span_j]
@@ -304,6 +303,10 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
                 current_list = None

                 doc.add_heading(text=text, prov=prov)
+            elif label == DocItemLabel.CODE:
+                current_list = None
+
+                doc.add_code(text=text, prov=prov)
             else:
                 current_list = None