added if statement for backend

Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com>
This commit is contained in:
Matteo Omenetti 2025-01-23 04:34:00 -05:00
parent d5b2c07295
commit 6206687e8b
8 changed files with 398 additions and 44 deletions

View File

@ -1,17 +1,11 @@
import logging
import os
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Any, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from typing_extensions import deprecated
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
_log = logging.getLogger(__name__)
@ -225,6 +219,8 @@ class PdfPipelineOptions(PipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_code_enrichment: bool = False # True: perform code OCR
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[

View File

@ -0,0 +1,274 @@
import re
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple
from docling_core.types.doc import CodeItem, DoclingDocument, NodeItem, TextItem
from docling_core.types.doc.base import BoundingBox
from docling_core.types.doc.labels import CodeLanguageLabel, DocItemLabel
from PIL import Image
from pydantic import BaseModel
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.utils.accelerator_utils import decide_device
class CodeFormulaModelOptions(BaseModel):
"""
Configuration options for the CodeFormulaModel.
Attributes
----------
kind : str
Type of the model. Fixed value "code_formula".
do_code_enrichment : bool
True if code enrichment is enabled, False otherwise.
do_formula_enrichment : bool
True if formula enrichment is enabled, False otherwise.
"""
kind: Literal["code_formula"] = "code_formula"
do_code_enrichment: bool = True
do_formula_enrichment: bool = True
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
"""
Model for processing and enriching documents with code and formula predictions.
Attributes
----------
enabled : bool
True if the model is enabled, False otherwise.
options : CodeFormulaModelOptions
Configuration options for the CodeFormulaModel.
code_formula_model : CodeFormulaPredictor
The predictor model for code and formula processing.
Methods
-------
__init__(self, enabled, artifacts_path, accelerator_options, code_formula_options)
Initializes the CodeFormulaModel with the given configuration options.
is_processable(self, doc, element)
Determines if a given element in a document can be processed by the model.
__call__(self, doc, element_batch)
Processes the given batch of elements and enriches them with predictions.
"""
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: CodeFormulaModelOptions,
accelerator_options: AcceleratorOptions,
):
"""
Initializes the CodeFormulaModel with the given configuration.
Parameters
----------
enabled : bool
True if the model is enabled, False otherwise.
artifacts_path : Path
Path to the directory containing the model artifacts.
options : CodeFormulaModelOptions
Configuration options for the model.
accelerator_options : AcceleratorOptions
Options specifying the device and number of threads for acceleration.
"""
self.enabled = enabled
self.options = options
if self.enabled:
device = decide_device(accelerator_options.device)
from docling_ibm_models.code_formula_model.code_formula_predictor import (
CodeFormulaPredictor,
)
if artifacts_path is None:
artifacts_path = self.download_models_hf()
self.code_formula_model = CodeFormulaPredictor(
artifacts_path=artifacts_path,
device=device,
num_threads=accelerator_options.num_threads,
)
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/CodeFormula",
force_download=force,
local_dir=local_dir,
revision="v1.0.0",
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if a given element in a document can be processed by the model.
Parameters
----------
doc : DoclingDocument
The document being processed.
element : NodeItem
The element within the document to check.
Returns
-------
bool
True if the element can be processed, False otherwise.
"""
return self.enabled and (
(isinstance(element, CodeItem) and self.options.do_code_enrichment)
or (
isinstance(element, TextItem)
and element.label == DocItemLabel.FORMULA
and self.options.do_formula_enrichment
)
)
def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
"""Extracts a programming language from the beginning of a (possibly multi-line) string.
This function checks if the input string starts with a pattern of the form
``<_some_language_>``. If it does, it extracts the language string and returns
a tuple of (remainder, language). Otherwise, it returns the original string
and `None`.
Args:
input_string (str): The input string, which may start with ``<_language_>``.
Returns:
Tuple[str, Optional[str]]:
A tuple where:
- The first element is either:
- The remainder of the string (everything after ``<_language_>``),
if a match is found; or
- The original string, if no match is found.
- The second element is the extracted language if a match is found;
otherwise, `None`.
"""
# Explanation of the regex:
# ^<_([^>]+)> : match "<_something>" at the start, capturing "something" (Group 1)
# \s* : optional whitespace
# (.*) : capture everything after that in Group 2
#
# We also use re.DOTALL so that the (.*) part can include newlines.
pattern = r"^<_([^>]+)_>\s*(.*)"
match = re.match(pattern, input_string, flags=re.DOTALL)
if match:
language = str(match.group(1)) # the captured programming language
remainder = str(match.group(2)) # everything after the <_language_>
return remainder, language
else:
return input_string, None
def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
"""
Converts a string to a corresponding `CodeLanguageLabel` enum member.
If the provided string does not match any value in `CodeLanguageLabel`,
it defaults to `CodeLanguageLabel.UNKNOWN`.
Args:
value (Optional[str]): The string representation of the code language or None.
Returns:
CodeLanguageLabel: The corresponding enum member if the value is valid,
otherwise `CodeLanguageLabel.UNKNOWN`.
"""
if not isinstance(value, str):
return CodeLanguageLabel.UNKNOWN
try:
return CodeLanguageLabel(value)
except ValueError:
return CodeLanguageLabel.UNKNOWN
def prepare_element(
self, conv_res: ConversionResult, element: NodeItem
) -> Optional[ItemAndImageEnrichmentElement]:
if not self.is_processable(doc=conv_res.document, element=element):
return None
assert isinstance(element, TextItem)
element_prov = element.prov[0]
expansion_factor = 0.03 # Adjust the expansion percentage as needed
bbox = element_prov.bbox
width = bbox.r - bbox.l
height = bbox.t - bbox.b
# Create the expanded bounding box
expanded_bbox = BoundingBox(
l=bbox.l - width * expansion_factor, # Expand left
t=bbox.t + height * expansion_factor, # Expand top
r=bbox.r + width * expansion_factor, # Expand right
b=bbox.b - height * expansion_factor, # Expand bottom
coord_origin=bbox.coord_origin, # Preserve coordinate origin
)
page_ix = element_prov.page_no - 1
cropped_image = conv_res.pages[page_ix].get_image(
scale=self.images_scale, cropbox=expanded_bbox
)
return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]:
"""
Processes the given batch of elements and enriches them with predictions.
Parameters
----------
doc : DoclingDocument
The document being processed.
element_batch : Iterable[NodeItem]
A batch of elements to be processed.
Returns
-------
Iterable[Any]
An iterable of enriched elements.
"""
if not self.enabled:
for element in element_batch:
yield element.item
return
labels: List[str] = []
images: List[Image.Image] = []
elements: List[TextItem] = []
for el in element_batch:
assert isinstance(el.item, TextItem)
elements.append(el.item)
labels.append(el.item.label)
images.append(el.image)
outputs = self.code_formula_model.predict(images, labels)
for item, output in zip(elements, outputs):
if isinstance(item, CodeItem):
output, code_language = self._extract_code_language(output)
item.code_language = self._get_code_language_enum(code_language)
item.text = output
yield item

View File

@ -3,7 +3,7 @@ import logging
import time
import traceback
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Any, Callable, Iterable, List
from docling_core.types.doc import DoclingDocument, NodeItem
@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseEnrichmentModel
from docling.models.base_model import BaseEnrichmentModel, GenericEnrichmentModel
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
@ -30,7 +30,7 @@ class BasePipeline(ABC):
self.pipeline_options = pipeline_options
self.keep_images = False
self.build_pipe: List[Callable] = []
self.enrichment_pipe: List[BaseEnrichmentModel] = []
self.enrichment_pipe: List[GenericEnrichmentModel[Any]] = []
def execute(self, in_doc: InputDocument, raises_on_error: bool) -> ConversionResult:
conv_res = ConversionResult(input=in_doc)
@ -66,7 +66,7 @@ class BasePipeline(ABC):
def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
def _prepare_elements(
conv_res: ConversionResult, model: BaseEnrichmentModel
conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
) -> Iterable[NodeItem]:
for doc_element, _level in conv_res.document.iterate_items():
prepared_element = model.prepare_element(

View File

@ -1,7 +1,7 @@
import logging
import sys
from pathlib import Path
from typing import Iterable, Optional
from typing import Optional
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
@ -17,8 +17,8 @@ from docling.datamodel.pipeline_options import (
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.models.base_model import BasePageModel
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
@ -93,8 +93,25 @@ class StandardPdfPipeline(PaginatedPipeline):
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
# Code Formula Enrichment Model
CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment,
artifacts_path=None,
options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment,
),
accelerator_options=pipeline_options.accelerator_options,
),
]
if (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
):
self.keep_backend = True
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False

56
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@ -888,13 +888,13 @@ chunking = ["semchunk (>=2.2.0,<3.0.0)", "transformers (>=4.34.0,<5.0.0)"]
[[package]]
name = "docling-ibm-models"
version = "3.2.0"
version = "3.2.1"
description = "This package contains the AI models used by the Docling PDF conversion package"
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "docling_ibm_models-3.2.0-py3-none-any.whl", hash = "sha256:9de784dc00f8e6db4f2acaf934bad133477b5a230f19d030e3d9ebb44e453c8e"},
{file = "docling_ibm_models-3.2.0.tar.gz", hash = "sha256:b0329256fb1464d51854f1654a4e09cbb812edfeaa104b45677952c7135c5ef8"},
{file = "docling_ibm_models-3.2.1-py3-none-any.whl", hash = "sha256:55bca5673381cc5862f4de584345020d071414c46bc1b9f6436d674e3610ec97"},
{file = "docling_ibm_models-3.2.1.tar.gz", hash = "sha256:abd1bdc58f00600065eedbfbd34876704d5004cd20884a2c0a61ca2ee5a927dd"},
]
[package.dependencies]
@ -1074,18 +1074,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc
[[package]]
name = "filelock"
version = "3.16.1"
version = "3.17.0"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"},
{file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"},
{file = "filelock-3.17.0-py3-none-any.whl", hash = "sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338"},
{file = "filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e"},
]
[package.extras]
docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"]
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"]
docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
typing = ["typing-extensions (>=4.12.2)"]
[[package]]
@ -1134,13 +1134,13 @@ dev = ["pyTest", "pyTest-cov"]
[[package]]
name = "flatbuffers"
version = "24.12.23"
version = "25.1.21"
description = "The FlatBuffers serialization format for Python"
optional = true
python-versions = "*"
files = [
{file = "flatbuffers-24.12.23-py2.py3-none-any.whl", hash = "sha256:c418e0d48890f4142b92fd3e343e73a48f194e1f80075ddcc5793779b3585444"},
{file = "flatbuffers-24.12.23.tar.gz", hash = "sha256:2910b0bc6ae9b6db78dd2b18d0b7a0709ba240fb5585f286a3a2b30785c22dac"},
{file = "flatbuffers-25.1.21-py2.py3-none-any.whl", hash = "sha256:0e9736098ba8f4e48246a0640390f4992c0b1a734e7322a9463d5c3eea00558b"},
{file = "flatbuffers-25.1.21.tar.gz", hash = "sha256:e24a34dcd9fb4e0ea8cc0fc8ef9c5cd61c9d21527a6d536967587a37a4ff9676"},
]
[[package]]
@ -3823,10 +3823,10 @@ files = [
numpy = [
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
[[package]]
@ -3849,10 +3849,10 @@ files = [
numpy = [
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
[[package]]
@ -4037,8 +4037,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -4818,13 +4818,13 @@ testutils = ["gitpython (>3)"]
[[package]]
name = "pymdown-extensions"
version = "10.14"
version = "10.14.1"
description = "Extension pack for Python Markdown."
optional = false
python-versions = ">=3.8"
files = [
{file = "pymdown_extensions-10.14-py3-none-any.whl", hash = "sha256:202481f716cc8250e4be8fce997781ebf7917701b59652458ee47f2401f818b5"},
{file = "pymdown_extensions-10.14.tar.gz", hash = "sha256:741bd7c4ff961ba40b7528d32284c53bc436b8b1645e8e37c3e57770b8700a34"},
{file = "pymdown_extensions-10.14.1-py3-none-any.whl", hash = "sha256:637951cbfbe9874ba28134fb3ce4b8bcadd6aca89ac4998ec29dcbafd554ae08"},
{file = "pymdown_extensions-10.14.1.tar.gz", hash = "sha256:b65801996a0cd4f42a3110810c306c45b7313c09b0610a6f773730f2a9e3c96b"},
]
[package.dependencies]
@ -4836,13 +4836,13 @@ extra = ["pygments (>=2.19.1)"]
[[package]]
name = "pymilvus"
version = "2.5.3"
version = "2.5.4"
description = "Python Sdk for Milvus"
optional = false
python-versions = ">=3.8"
files = [
{file = "pymilvus-2.5.3-py3-none-any.whl", hash = "sha256:64ca63594284586937274800be27a402f3be2d078130bf81d94ab8d7798ac9c8"},
{file = "pymilvus-2.5.3.tar.gz", hash = "sha256:68bc3797b7a14c494caf116cee888894ffd6eba7b96a3ac841be85d60694cc5d"},
{file = "pymilvus-2.5.4-py3-none-any.whl", hash = "sha256:3f7ddaeae0c8f63554b8e316b73f265d022e05a457d47c366ce47293434a3aea"},
{file = "pymilvus-2.5.4.tar.gz", hash = "sha256:611732428ff669d57ded3d1f823bdeb10febf233d0251cce8498b287e5a10ce8"},
]
[package.dependencies]
@ -7175,13 +7175,13 @@ files = [
[[package]]
name = "tzdata"
version = "2024.2"
version = "2025.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
{file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"},
{file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"},
{file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"},
{file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"},
]
[[package]]
@ -7751,4 +7751,4 @@ tesserocr = ["tesserocr"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "d326d96280673e5b8bed4536aea277fe88eb503872a91c59d11e43e859b003e6"
content-hash = "8bb0b67294a50c0340c5cc02ce60d3608ef4d1968ae50f7e0b8b4c8a26c34734"

View File

@ -27,7 +27,7 @@ packages = [{include = "docling"}]
python = "^3.9"
pydantic = "^2.0.0"
docling-core = { version = "^2.15.1", extras = ["chunking"] }
docling-ibm-models = "^3.1.0"
docling-ibm-models = "^3.2.1"
deepsearch-glm = "^1.0.0"
docling-parse = "^3.1.0"
filetype = "^1.2.0"

Binary file not shown.

View File

@ -0,0 +1,67 @@
from pathlib import Path
from docling_core.types.doc import CodeItem, TextItem
from docling_core.types.doc.labels import CodeLanguageLabel, DocItemLabel
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
def get_converter():
pipeline_options = PdfPipelineOptions()
pipeline_options.generate_page_images = True
pipeline_options.do_ocr = False
pipeline_options.do_table_structure = False
pipeline_options.do_code_enrichment = True
pipeline_options.do_formula_enrichment = True
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
backend=DoclingParseV2DocumentBackend,
pipeline_cls=StandardPdfPipeline,
pipeline_options=pipeline_options,
)
}
)
return converter
def test_code_and_formula_conversion():
pdf_path = Path("tests/data/code_and_formula.pdf")
converter = get_converter()
print(f"converting {pdf_path}")
doc_result: ConversionResult = converter.convert(pdf_path)
results = doc_result.document.texts
code_blocks = [el for el in results if isinstance(el, CodeItem)]
assert len(code_blocks) == 1
gt = 'public static void print() {\n System.out.println("Java Code");\n}'
print(gt)
predicted = code_blocks[0].text.strip()
assert predicted == gt, f"mismatch in text {predicted=}, {gt=}"
assert code_blocks[0].code_language == CodeLanguageLabel.JAVA
formula_blocks = [
el
for el in results
if isinstance(el, TextItem) and el.label == DocItemLabel.FORMULA
]
assert len(formula_blocks) == 1
gt = "a ^ { 2 } + 8 = 1 2"
predicted = formula_blocks[0].text
assert predicted == gt, f"mismatch in text {predicted=}, {gt=}"