diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index eeec6bab..efdf3b1c 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -1,17 +1,11 @@
 import logging
 import os
-import warnings
 from enum import Enum
 from pathlib import Path
-from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import Any, List, Literal, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from pydantic_settings import (
-    BaseSettings,
-    PydanticBaseSettingsSource,
-    SettingsConfigDict,
-)
-from typing_extensions import deprecated
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic_settings import BaseSettings, SettingsConfigDict
 
 _log = logging.getLogger(__name__)
 
@@ -225,6 +219,8 @@ class PdfPipelineOptions(PipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
+    do_code_enrichment: bool = False  # True: perform code OCR
+    do_formula_enrichment: bool = False  # True: perform formula OCR, return LaTeX code
 
     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[
diff --git a/docling/models/code_formula_model.py b/docling/models/code_formula_model.py
new file mode 100644
index 00000000..ca48980d
--- /dev/null
+++ b/docling/models/code_formula_model.py
@@ -0,0 +1,274 @@
+import re
+from pathlib import Path
+from typing import Iterable, List, Literal, Optional, Tuple
+
+from docling_core.types.doc import CodeItem, DoclingDocument, NodeItem, TextItem
+from docling_core.types.doc.base import BoundingBox
+from docling_core.types.doc.labels import CodeLanguageLabel, DocItemLabel
+from PIL import Image
+from pydantic import BaseModel
+
+from docling.datamodel.base_models import ItemAndImageEnrichmentElement
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import AcceleratorOptions
+from docling.models.base_model import BaseItemAndImageEnrichmentModel
+from docling.utils.accelerator_utils import decide_device
+
+
+class CodeFormulaModelOptions(BaseModel):
+    """
+    Configuration options for the CodeFormulaModel.
+
+    Attributes
+    ----------
+    kind : str
+        Type of the model. Fixed value "code_formula".
+    do_code_enrichment : bool
+        True if code enrichment is enabled, False otherwise.
+    do_formula_enrichment : bool
+        True if formula enrichment is enabled, False otherwise.
+    """
+
+    kind: Literal["code_formula"] = "code_formula"
+    do_code_enrichment: bool = True
+    do_formula_enrichment: bool = True
+
+
+class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
+    """
+    Model for processing and enriching documents with code and formula predictions.
+
+    Attributes
+    ----------
+    enabled : bool
+        True if the model is enabled, False otherwise.
+    options : CodeFormulaModelOptions
+        Configuration options for the CodeFormulaModel.
+    code_formula_model : CodeFormulaPredictor
+        The predictor model for code and formula processing.
+
+    Methods
+    -------
+    __init__(self, enabled, artifacts_path, options, accelerator_options)
+        Initializes the CodeFormulaModel with the given configuration options.
+    is_processable(self, doc, element)
+        Determines if a given element in a document can be processed by the model.
+    __call__(self, doc, element_batch)
+        Processes the given batch of elements and enriches them with predictions.
+    """
+
+    images_scale = 1.66  # = 120 dpi, aligned with training data resolution
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        options: CodeFormulaModelOptions,
+        accelerator_options: AcceleratorOptions,
+    ):
+        """
+        Initializes the CodeFormulaModel with the given configuration.
+
+        Parameters
+        ----------
+        enabled : bool
+            True if the model is enabled, False otherwise.
+        artifacts_path : Optional[Path]
+            Path to the directory containing the model artifacts; if None, they are downloaded from Hugging Face.
+        options : CodeFormulaModelOptions
+            Configuration options for the model.
+        accelerator_options : AcceleratorOptions
+            Options specifying the device and number of threads for acceleration.
+        """
+        self.enabled = enabled
+        self.options = options
+
+        if self.enabled:
+            device = decide_device(accelerator_options.device)
+
+            from docling_ibm_models.code_formula_model.code_formula_predictor import (
+                CodeFormulaPredictor,
+            )
+
+            if artifacts_path is None:
+                artifacts_path = self.download_models_hf()
+
+            self.code_formula_model = CodeFormulaPredictor(
+                artifacts_path=artifacts_path,
+                device=device,
+                num_threads=accelerator_options.num_threads,
+            )
+
+    @staticmethod
+    def download_models_hf(
+        local_dir: Optional[Path] = None, force: bool = False
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id="ds4sd/CodeFormula",
+            force_download=force,
+            local_dir=local_dir,
+            revision="v1.0.0",
+        )
+
+        return Path(download_path)
+
+    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        """
+        Determines if a given element in a document can be processed by the model.
+
+        Parameters
+        ----------
+        doc : DoclingDocument
+            The document being processed.
+        element : NodeItem
+            The element within the document to check.
+
+        Returns
+        -------
+        bool
+            True if the element can be processed, False otherwise.
+        """
+        return self.enabled and (
+            (isinstance(element, CodeItem) and self.options.do_code_enrichment)
+            or (
+                isinstance(element, TextItem)
+                and element.label == DocItemLabel.FORMULA
+                and self.options.do_formula_enrichment
+            )
+        )
+
+    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
+        """Extracts a programming language from the beginning of a (possibly multi-line) string.
+
+        This function checks if the input string starts with a pattern of the form
+        ``<_some_language_>``. If it does, it extracts the language string and returns
+        a tuple of (remainder, language). Otherwise, it returns the original string
+        and `None`.
+
+        Args:
+            input_string (str): The input string, which may start with ``<_language_>``.
+
+        Returns:
+            Tuple[str, Optional[str]]:
+                A tuple where:
+                - The first element is either:
+                    - The remainder of the string (everything after ``<_language_>``),
+                      if a match is found; or
+                    - The original string, if no match is found.
+                - The second element is the extracted language if a match is found;
+                  otherwise, `None`.
+        """
+        # Explanation of the regex:
+        #   ^<_([^>]+)_> : match "<_something_>" at the start, capturing "something" (Group 1)
+        #   \s*          : optional whitespace
+        #   (.*)         : capture everything after that in Group 2
+        #
+        # We also use re.DOTALL so that the (.*) part can include newlines.
+        pattern = r"^<_([^>]+)_>\s*(.*)"
+        match = re.match(pattern, input_string, flags=re.DOTALL)
+        if match:
+            language = str(match.group(1))  # the captured programming language
+            remainder = str(match.group(2))  # everything after the <_language_>
+            return remainder, language
+        else:
+            return input_string, None
+
+    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
+        """
+        Converts a string to a corresponding `CodeLanguageLabel` enum member.
+
+        If the provided string does not match any value in `CodeLanguageLabel`,
+        it defaults to `CodeLanguageLabel.UNKNOWN`.
+
+        Args:
+            value (Optional[str]): The string representation of the code language or None.
+
+        Returns:
+            CodeLanguageLabel: The corresponding enum member if the value is valid,
+                otherwise `CodeLanguageLabel.UNKNOWN`.
+        """
+        if not isinstance(value, str):
+            return CodeLanguageLabel.UNKNOWN
+
+        try:
+            return CodeLanguageLabel(value)
+        except ValueError:
+            return CodeLanguageLabel.UNKNOWN
+
+    def prepare_element(
+        self, conv_res: ConversionResult, element: NodeItem
+    ) -> Optional[ItemAndImageEnrichmentElement]:
+        if not self.is_processable(doc=conv_res.document, element=element):
+            return None
+
+        assert isinstance(element, TextItem)
+
+        element_prov = element.prov[0]
+
+        expansion_factor = 0.03  # Adjust the expansion percentage as needed
+        bbox = element_prov.bbox
+        width = bbox.r - bbox.l
+        height = bbox.t - bbox.b
+
+        # Create the expanded bounding box
+        expanded_bbox = BoundingBox(
+            l=bbox.l - width * expansion_factor,  # Expand left
+            t=bbox.t + height * expansion_factor,  # Expand top
+            r=bbox.r + width * expansion_factor,  # Expand right
+            b=bbox.b - height * expansion_factor,  # Expand bottom
+            coord_origin=bbox.coord_origin,  # Preserve coordinate origin
+        )
+
+        page_ix = element_prov.page_no - 1
+        cropped_image = conv_res.pages[page_ix].get_image(
+            scale=self.images_scale, cropbox=expanded_bbox
+        )
+        return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
+
+    def __call__(
+        self,
+        doc: DoclingDocument,
+        element_batch: Iterable[ItemAndImageEnrichmentElement],
+    ) -> Iterable[NodeItem]:
+        """
+        Processes the given batch of elements and enriches them with predictions.
+
+        Parameters
+        ----------
+        doc : DoclingDocument
+            The document being processed.
+        element_batch : Iterable[ItemAndImageEnrichmentElement]
+            A batch of elements, with their cropped images, to be processed.
+
+        Returns
+        -------
+        Iterable[NodeItem]
+            An iterable of enriched elements.
+        """
+        if not self.enabled:
+            for element in element_batch:
+                yield element.item
+            return
+
+        labels: List[str] = []
+        images: List[Image.Image] = []
+        elements: List[TextItem] = []
+        for el in element_batch:
+            assert isinstance(el.item, TextItem)
+            elements.append(el.item)
+            labels.append(el.item.label)
+            images.append(el.image)
+
+        outputs = self.code_formula_model.predict(images, labels)
+
+        for item, output in zip(elements, outputs):
+            if isinstance(item, CodeItem):
+                output, code_language = self._extract_code_language(output)
+                item.code_language = self._get_code_language_enum(code_language)
+            item.text = output
+
+            yield item
diff --git a/docling/pipeline/base_pipeline.py b/docling/pipeline/base_pipeline.py
index 034e6d42..656c3bc3 100644
--- a/docling/pipeline/base_pipeline.py
+++ b/docling/pipeline/base_pipeline.py
@@ -3,7 +3,7 @@ import logging
 import time
 import traceback
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Any, Callable, Iterable, List
 
 from docling_core.types.doc import DoclingDocument, NodeItem
 
@@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
 from docling.datamodel.document import ConversionResult, InputDocument
 from docling.datamodel.pipeline_options import PipelineOptions
 from docling.datamodel.settings import settings
-from docling.models.base_model import BaseEnrichmentModel
+from docling.models.base_model import BaseEnrichmentModel, GenericEnrichmentModel
 from docling.utils.profiling import ProfilingScope, TimeRecorder
 from docling.utils.utils import chunkify
 
@@ -30,7 +30,7 @@ class BasePipeline(ABC):
         self.pipeline_options = pipeline_options
         self.keep_images = False
         self.build_pipe: List[Callable] = []
-        self.enrichment_pipe: List[BaseEnrichmentModel] = []
+        self.enrichment_pipe: List[GenericEnrichmentModel[Any]] = []
 
     def execute(self, in_doc: InputDocument, raises_on_error: bool) -> ConversionResult:
         conv_res = ConversionResult(input=in_doc)
@@ -66,7 +66,7 @@ class BasePipeline(ABC):
     def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
 
         def _prepare_elements(
-            conv_res: ConversionResult, model: BaseEnrichmentModel
+            conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
         ) -> Iterable[NodeItem]:
             for doc_element, _level in conv_res.document.iterate_items():
                 prepared_element = model.prepare_element(
diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py
index 758f4e94..0340a51b 100644
--- a/docling/pipeline/standard_pdf_pipeline.py
+++ b/docling/pipeline/standard_pdf_pipeline.py
@@ -1,7 +1,7 @@
 import logging
 import sys
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import Optional
 
 from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
 
@@ -17,8 +17,8 @@ from docling.datamodel.pipeline_options import (
     TesseractCliOcrOptions,
     TesseractOcrOptions,
 )
-from docling.models.base_model import BasePageModel
 from docling.models.base_ocr_model import BaseOcrModel
+from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
 from docling.models.ds_glm_model import GlmModel, GlmOptions
 from docling.models.easyocr_model import EasyOcrModel
 from docling.models.layout_model import LayoutModel
@@ -93,8 +93,25 @@ class StandardPdfPipeline(PaginatedPipeline):
 
         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
+            # Code Formula Enrichment Model
+            CodeFormulaModel(
+                enabled=pipeline_options.do_code_enrichment
+                or pipeline_options.do_formula_enrichment,
+                artifacts_path=None,
+                options=CodeFormulaModelOptions(
+                    do_code_enrichment=pipeline_options.do_code_enrichment,
+                    do_formula_enrichment=pipeline_options.do_formula_enrichment,
+                ),
+                accelerator_options=pipeline_options.accelerator_options,
+            ),
         ]
 
+        if (
+            self.pipeline_options.do_formula_enrichment
+            or self.pipeline_options.do_code_enrichment
+        ):
+            self.keep_backend = True
+
     @staticmethod
     def download_models_hf(
         local_dir: Optional[Path] = None, force: bool = False
diff --git a/poetry.lock b/poetry.lock
index 1ae80adb..73fc85db 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -888,13 +888,13 @@ chunking = ["semchunk (>=2.2.0,<3.0.0)", "transformers (>=4.34.0,<5.0.0)"]
 
 [[package]]
 name = "docling-ibm-models"
-version = "3.2.0"
+version = "3.2.1"
 description = "This package contains the AI models used by the Docling PDF conversion package"
 optional = false
 python-versions = "<4.0,>=3.9"
 files = [
-    {file = "docling_ibm_models-3.2.0-py3-none-any.whl", hash = "sha256:9de784dc00f8e6db4f2acaf934bad133477b5a230f19d030e3d9ebb44e453c8e"},
-    {file = "docling_ibm_models-3.2.0.tar.gz", hash = "sha256:b0329256fb1464d51854f1654a4e09cbb812edfeaa104b45677952c7135c5ef8"},
+    {file = "docling_ibm_models-3.2.1-py3-none-any.whl", hash = "sha256:55bca5673381cc5862f4de584345020d071414c46bc1b9f6436d674e3610ec97"},
+    {file = "docling_ibm_models-3.2.1.tar.gz", hash = "sha256:abd1bdc58f00600065eedbfbd34876704d5004cd20884a2c0a61ca2ee5a927dd"},
 ]
 
 [package.dependencies]
@@ -1074,18 +1074,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc
 
 [[package]]
 name = "filelock"
-version = "3.16.1"
+version = "3.17.0"
 description = "A platform independent file lock."
 optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
 files = [
-    {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"},
-    {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"},
+    {file = "filelock-3.17.0-py3-none-any.whl", hash = "sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338"},
+    {file = "filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e"},
 ]
 
 [package.extras]
-docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
 typing = ["typing-extensions (>=4.12.2)"]
 
 [[package]]
@@ -1134,13 +1134,13 @@ dev = ["pyTest", "pyTest-cov"]
 
 [[package]]
 name = "flatbuffers"
-version = "24.12.23"
+version = "25.1.21"
 description = "The FlatBuffers serialization format for Python"
 optional = true
 python-versions = "*"
 files = [
-    {file = "flatbuffers-24.12.23-py2.py3-none-any.whl", hash = "sha256:c418e0d48890f4142b92fd3e343e73a48f194e1f80075ddcc5793779b3585444"},
-    {file = "flatbuffers-24.12.23.tar.gz", hash = "sha256:2910b0bc6ae9b6db78dd2b18d0b7a0709ba240fb5585f286a3a2b30785c22dac"},
+    {file = "flatbuffers-25.1.21-py2.py3-none-any.whl", hash = "sha256:0e9736098ba8f4e48246a0640390f4992c0b1a734e7322a9463d5c3eea00558b"},
+    {file = "flatbuffers-25.1.21.tar.gz", hash = "sha256:e24a34dcd9fb4e0ea8cc0fc8ef9c5cd61c9d21527a6d536967587a37a4ff9676"},
 ]
 
 [[package]]
@@ -3823,10 +3823,10 @@ files = [
 numpy = [
     {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
     {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
     {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
 ]
 
 [[package]]
@@ -3849,10 +3849,10 @@ files = [
 numpy = [
     {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
     {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
     {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
 ]
 
 [[package]]
@@ -4037,8 +4037,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.22.4", markers = "python_version < \"3.11\""},
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
     {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -4818,13 +4818,13 @@ testutils = ["gitpython (>3)"]
 
 [[package]]
 name = "pymdown-extensions"
-version = "10.14"
+version = "10.14.1"
 description = "Extension pack for Python Markdown."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pymdown_extensions-10.14-py3-none-any.whl", hash = "sha256:202481f716cc8250e4be8fce997781ebf7917701b59652458ee47f2401f818b5"},
-    {file = "pymdown_extensions-10.14.tar.gz", hash = "sha256:741bd7c4ff961ba40b7528d32284c53bc436b8b1645e8e37c3e57770b8700a34"},
+    {file = "pymdown_extensions-10.14.1-py3-none-any.whl", hash = "sha256:637951cbfbe9874ba28134fb3ce4b8bcadd6aca89ac4998ec29dcbafd554ae08"},
+    {file = "pymdown_extensions-10.14.1.tar.gz", hash = "sha256:b65801996a0cd4f42a3110810c306c45b7313c09b0610a6f773730f2a9e3c96b"},
 ]
 
 [package.dependencies]
@@ -4836,13 +4836,13 @@ extra = ["pygments (>=2.19.1)"]
 
 [[package]]
 name = "pymilvus"
-version = "2.5.3"
+version = "2.5.4"
 description = "Python Sdk for Milvus"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pymilvus-2.5.3-py3-none-any.whl", hash = "sha256:64ca63594284586937274800be27a402f3be2d078130bf81d94ab8d7798ac9c8"},
-    {file = "pymilvus-2.5.3.tar.gz", hash = "sha256:68bc3797b7a14c494caf116cee888894ffd6eba7b96a3ac841be85d60694cc5d"},
+    {file = "pymilvus-2.5.4-py3-none-any.whl", hash = "sha256:3f7ddaeae0c8f63554b8e316b73f265d022e05a457d47c366ce47293434a3aea"},
+    {file = "pymilvus-2.5.4.tar.gz", hash = "sha256:611732428ff669d57ded3d1f823bdeb10febf233d0251cce8498b287e5a10ce8"},
 ]
 
 [package.dependencies]
@@ -7175,13 +7175,13 @@ files = [
 
 [[package]]
 name = "tzdata"
-version = "2024.2"
+version = "2025.1"
 description = "Provider of IANA time zone data"
 optional = false
 python-versions = ">=2"
 files = [
-    {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"},
-    {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"},
+    {file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"},
+    {file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"},
 ]
 
 [[package]]
@@ -7751,4 +7751,4 @@ tesserocr = ["tesserocr"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "d326d96280673e5b8bed4536aea277fe88eb503872a91c59d11e43e859b003e6"
+content-hash = "8bb0b67294a50c0340c5cc02ce60d3608ef4d1968ae50f7e0b8b4c8a26c34734"
diff --git a/pyproject.toml b/pyproject.toml
index 237b080d..c3e1fa67 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ packages = [{include = "docling"}]
 python = "^3.9"
 pydantic = "^2.0.0"
 docling-core = { version = "^2.15.1", extras = ["chunking"] }
-docling-ibm-models = "^3.1.0"
+docling-ibm-models = "^3.2.1"
 deepsearch-glm = "^1.0.0"
 docling-parse = "^3.1.0"
 filetype = "^1.2.0"
diff --git a/tests/data/code_and_formula.pdf b/tests/data/code_and_formula.pdf
new file mode 100644
index 00000000..82cd8343
Binary files /dev/null and b/tests/data/code_and_formula.pdf differ
diff --git a/tests/test_code_formula.py b/tests/test_code_formula.py
new file mode 100644
index 00000000..f7843286
--- /dev/null
+++ b/tests/test_code_formula.py
@@ -0,0 +1,67 @@
+from pathlib import Path
+
+from docling_core.types.doc import CodeItem, TextItem
+from docling_core.types.doc.labels import CodeLanguageLabel, DocItemLabel
+
+from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
+from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
+from docling.datamodel.base_models import InputFormat
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import PdfPipelineOptions
+from docling.document_converter import DocumentConverter, PdfFormatOption
+from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
+
+
+def get_converter():
+
+    pipeline_options = PdfPipelineOptions()
+    pipeline_options.generate_page_images = True
+
+    pipeline_options.do_ocr = False
+    pipeline_options.do_table_structure = False
+    pipeline_options.do_code_enrichment = True
+    pipeline_options.do_formula_enrichment = True
+
+    converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                backend=DoclingParseV2DocumentBackend,
+                pipeline_cls=StandardPdfPipeline,
+                pipeline_options=pipeline_options,
+            )
+        }
+    )
+
+    return converter
+
+
+def test_code_and_formula_conversion():
+    pdf_path = Path("tests/data/code_and_formula.pdf")
+    converter = get_converter()
+
+    print(f"converting {pdf_path}")
+
+    doc_result: ConversionResult = converter.convert(pdf_path)
+
+    results = doc_result.document.texts
+
+    code_blocks = [el for el in results if isinstance(el, CodeItem)]
+    assert len(code_blocks) == 1
+
+    gt = 'public static void print() {\n System.out.println("Java Code");\n}'
+    print(gt)
+
+    predicted = code_blocks[0].text.strip()
+    assert predicted == gt, f"mismatch in text {predicted=}, {gt=}"
+    assert code_blocks[0].code_language == CodeLanguageLabel.JAVA
+
+    formula_blocks = [
+        el
+        for el in results
+        if isinstance(el, TextItem) and el.label == DocItemLabel.FORMULA
+    ]
+    assert len(formula_blocks) == 1
+
+    gt = "a ^ { 2 } + 8 = 1 2"
+    predicted = formula_blocks[0].text
+    assert predicted == gt, f"mismatch in text {predicted=}, {gt=}"
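
Usage sketch (not part of the patch): the snippet below shows how the two new pipeline flags can be switched on through the public API, mirroring the converter setup in tests/test_code_formula.py. It relies only on options and types that appear in this diff; the input path "sample.pdf" is a placeholder, and the default PDF backend is assumed rather than the DoclingParseV2DocumentBackend pinned in the test.

    from docling_core.types.doc import CodeItem

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption

    # Enable the enrichment models introduced in this patch.
    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_code_enrichment = True
    pipeline_options.do_formula_enrichment = True

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
        }
    )

    # "sample.pdf" is a placeholder for any PDF containing code listings or formulas.
    result = converter.convert("sample.pdf")

    for item in result.document.texts:
        # CodeItem elements now carry a detected code_language; formula TextItem
        # elements have their text replaced by the predicted LaTeX string.
        if isinstance(item, CodeItem):
            print(item.code_language, item.text)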