mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00
test: improve typing definitions (part 1)
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
79932b7d69
commit
594dc84245
@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from docling.datamodel.base_models import BoundingBox, Cell, PageSize
|
||||
|
||||
|
||||
class PdfPageBackend(ABC):
|
||||
|
||||
@ -17,12 +20,12 @@ class PdfPageBackend(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_bitmap_rects(self, scale: int = 1) -> Iterable["BoundingBox"]:
|
||||
def get_bitmap_rects(self, float: int = 1) -> Iterable["BoundingBox"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_page_image(
|
||||
self, scale: int = 1, cropbox: Optional["BoundingBox"] = None
|
||||
self, scale: float = 1, cropbox: Optional["BoundingBox"] = None
|
||||
) -> Image.Image:
|
||||
pass
|
||||
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
import random
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Union
|
||||
from typing import Iterable, List, Optional, Union
|
||||
|
||||
import pypdfium2 as pdfium
|
||||
from docling_parse.docling_parse import pdf_parser
|
||||
@ -22,7 +22,6 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
self._ppage = page_obj
|
||||
parsed_page = parser.parse_pdf_from_key_on_page(document_hash, page_no)
|
||||
|
||||
self._dpage = None
|
||||
self.valid = "pages" in parsed_page
|
||||
if self.valid:
|
||||
self._dpage = parsed_page["pages"][0]
|
||||
@ -68,7 +67,7 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
return text_piece
|
||||
|
||||
def get_text_cells(self) -> Iterable[Cell]:
|
||||
cells = []
|
||||
cells: List[Cell] = []
|
||||
cell_counter = 0
|
||||
|
||||
if not self.valid:
|
||||
@ -130,7 +129,7 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
|
||||
return cells
|
||||
|
||||
def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
|
||||
def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
|
||||
AREA_THRESHOLD = 32 * 32
|
||||
|
||||
for i in range(len(self._dpage["images"])):
|
||||
@ -145,7 +144,7 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
yield cropbox
|
||||
|
||||
def get_page_image(
|
||||
self, scale: int = 1, cropbox: Optional[BoundingBox] = None
|
||||
self, scale: float = 1, cropbox: Optional[BoundingBox] = None
|
||||
) -> Image.Image:
|
||||
|
||||
page_size = self.get_size()
|
||||
|
@ -7,7 +7,7 @@ from typing import Iterable, List, Optional, Union
|
||||
import pypdfium2 as pdfium
|
||||
import pypdfium2.raw as pdfium_c
|
||||
from PIL import Image, ImageDraw
|
||||
from pypdfium2 import PdfPage
|
||||
from pypdfium2 import PdfPage, PdfTextPage
|
||||
from pypdfium2._helpers.misc import PdfiumError
|
||||
|
||||
from docling.backend.abstract_backend import PdfDocumentBackend, PdfPageBackend
|
||||
@ -29,12 +29,12 @@ class PyPdfiumPageBackend(PdfPageBackend):
|
||||
exc_info=True,
|
||||
)
|
||||
self.valid = False
|
||||
self.text_page = None
|
||||
self.text_page: Optional[PdfTextPage] = None
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self.valid
|
||||
|
||||
def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
|
||||
def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
|
||||
AREA_THRESHOLD = 32 * 32
|
||||
for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]):
|
||||
pos = obj.get_pos()
|
||||
@ -189,7 +189,7 @@ class PyPdfiumPageBackend(PdfPageBackend):
|
||||
return cells
|
||||
|
||||
def get_page_image(
|
||||
self, scale: int = 1, cropbox: Optional[BoundingBox] = None
|
||||
self, scale: float = 1, cropbox: Optional[BoundingBox] = None
|
||||
) -> Image.Image:
|
||||
|
||||
page_size = self.get_size()
|
||||
|
@ -87,7 +87,7 @@ class BoundingBox(BaseModel):
|
||||
return (self.l, self.b, self.r, self.t)
|
||||
|
||||
@classmethod
|
||||
def from_tuple(cls, coord: Tuple[float], origin: CoordOrigin):
|
||||
def from_tuple(cls, coord: Tuple[float, ...], origin: CoordOrigin):
|
||||
if origin == CoordOrigin.TOPLEFT:
|
||||
l, t, r, b = coord[0], coord[1], coord[2], coord[3]
|
||||
if r < l:
|
||||
@ -246,7 +246,7 @@ class EquationPrediction(BaseModel):
|
||||
|
||||
|
||||
class PagePredictions(BaseModel):
|
||||
layout: LayoutPrediction = None
|
||||
layout: Optional[LayoutPrediction] = None
|
||||
tablestructure: Optional[TableStructurePrediction] = None
|
||||
figures_classification: Optional[FigureClassificationPrediction] = None
|
||||
equations_prediction: Optional[EquationPrediction] = None
|
||||
@ -267,7 +267,7 @@ class Page(BaseModel):
|
||||
page_no: int
|
||||
page_hash: Optional[str] = None
|
||||
size: Optional[PageSize] = None
|
||||
cells: List[Cell] = None
|
||||
cells: List[Cell] = []
|
||||
predictions: PagePredictions = PagePredictions()
|
||||
assembled: Optional[AssembledUnit] = None
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Callable, Iterable, List
|
||||
|
||||
from docling.datamodel.base_models import Page, PipelineOptions
|
||||
|
||||
|
||||
class BaseModelPipeline:
|
||||
def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
|
||||
self.model_pipe = []
|
||||
self.model_pipe: List[Callable] = []
|
||||
self.artifacts_path = artifacts_path
|
||||
self.pipeline_options = pipeline_options
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
from docling_core.types.doc.base import BaseCell, Ref, Table, TableCell
|
||||
from docling_core.types.doc.base import BaseCell, BaseText, Ref, Table, TableCell
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell
|
||||
from docling.datamodel.document import ConvertedDocument, Page
|
||||
from docling.datamodel.document import ConversionResult, Page
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
@ -15,6 +15,9 @@ def _export_table_to_html(table: Table):
|
||||
# to the docling-core package.
|
||||
|
||||
def _get_tablecell_span(cell: TableCell, ix):
|
||||
if cell.spans is None:
|
||||
span = set()
|
||||
else:
|
||||
span = set([s[ix] for s in cell.spans])
|
||||
if len(span) == 0:
|
||||
return 1, None, None
|
||||
@ -24,6 +27,8 @@ def _export_table_to_html(table: Table):
|
||||
nrows = table.num_rows
|
||||
ncols = table.num_cols
|
||||
|
||||
if table.data is None:
|
||||
return ""
|
||||
for i in range(nrows):
|
||||
body += "<tr>"
|
||||
for j in range(ncols):
|
||||
@ -66,7 +71,7 @@ def _export_table_to_html(table: Table):
|
||||
|
||||
|
||||
def generate_multimodal_pages(
|
||||
doc_result: ConvertedDocument,
|
||||
doc_result: ConversionResult,
|
||||
) -> Iterable[Tuple[str, str, List[Dict[str, Any]], List[Dict[str, Any]], Page]]:
|
||||
|
||||
label_to_doclaynet = {
|
||||
@ -94,7 +99,7 @@ def generate_multimodal_pages(
|
||||
page_no = 0
|
||||
start_ix = 0
|
||||
end_ix = 0
|
||||
doc_items = []
|
||||
doc_items: List[Tuple[int, Union[BaseCell, BaseText]]] = []
|
||||
|
||||
doc = doc_result.output
|
||||
|
||||
@ -105,11 +110,11 @@ def generate_multimodal_pages(
|
||||
item_type = item.obj_type
|
||||
label = label_to_doclaynet.get(item_type, None)
|
||||
|
||||
if label is None:
|
||||
if label is None or item.prov is None or page.size is None:
|
||||
continue
|
||||
|
||||
bbox = BoundingBox.from_tuple(
|
||||
item.prov[0].bbox, origin=CoordOrigin.BOTTOMLEFT
|
||||
tuple(item.prov[0].bbox), origin=CoordOrigin.BOTTOMLEFT
|
||||
)
|
||||
new_bbox = bbox.to_top_left_origin(page_height=page.size.height).normalized(
|
||||
page_size=page.size
|
||||
@ -137,13 +142,15 @@ def generate_multimodal_pages(
|
||||
return segments
|
||||
|
||||
def _process_page_cells(page: Page):
|
||||
cells = []
|
||||
cells: List[dict] = []
|
||||
if page.size is None:
|
||||
return cells
|
||||
for cell in page.cells:
|
||||
new_bbox = cell.bbox.to_top_left_origin(
|
||||
page_height=page.size.height
|
||||
).normalized(page_size=page.size)
|
||||
is_ocr = isinstance(cell, OcrCell)
|
||||
ocr_confidence = cell.confidence if is_ocr else 1.0
|
||||
ocr_confidence = cell.confidence if isinstance(cell, OcrCell) else 1.0
|
||||
cells.append(
|
||||
{
|
||||
"text": cell.text,
|
||||
@ -170,6 +177,8 @@ def generate_multimodal_pages(
|
||||
|
||||
return content_text, content_md, content_dt, page_cells, page_segments, page
|
||||
|
||||
if doc.main_text is None:
|
||||
return
|
||||
for ix, orig_item in enumerate(doc.main_text):
|
||||
|
||||
item = doc._resolve_ref(orig_item) if isinstance(orig_item, Ref) else orig_item
|
||||
|
38
poetry.lock
generated
38
poetry.lock
generated
@ -163,8 +163,8 @@ files = [
|
||||
lazy-object-proxy = ">=1.4.0"
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||
wrapt = [
|
||||
{version = ">=1.14,<2", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.11,<2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.14,<2", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2867,9 +2867,9 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
]
|
||||
|
||||
@ -2924,8 +2924,8 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
@ -2957,6 +2957,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
|
||||
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
|
||||
xml = ["lxml (>=4.9.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "pandas-stubs"
|
||||
version = "2.2.2.240909"
|
||||
description = "Type annotations for pandas"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
files = [
|
||||
{file = "pandas_stubs-2.2.2.240909-py3-none-any.whl", hash = "sha256:e230f5fa4065f9417804f4d65cd98f86c002efcc07933e8abcd48c3fad9c30a2"},
|
||||
{file = "pandas_stubs-2.2.2.240909.tar.gz", hash = "sha256:3c0951a2c3e45e3475aed9d80b7147ae82f176b9e42e9fb321cfdebf3d411b3d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.23.5"
|
||||
types-pytz = ">=2022.1.1"
|
||||
|
||||
[[package]]
|
||||
name = "parso"
|
||||
version = "0.8.4"
|
||||
@ -3354,8 +3369,8 @@ files = [
|
||||
annotated-types = ">=0.6.0"
|
||||
pydantic-core = "2.23.3"
|
||||
typing-extensions = [
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@ -3523,8 +3538,8 @@ files = [
|
||||
astroid = ">=2.15.8,<=2.17.0-dev0"
|
||||
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
|
||||
dill = [
|
||||
{version = ">=0.3.6", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=0.2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=0.3.6", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
isort = ">=4.2.5,<6"
|
||||
mccabe = ">=0.6,<0.8"
|
||||
@ -4939,6 +4954,17 @@ rfc3986 = ">=1.4.0"
|
||||
tqdm = ">=4.14"
|
||||
urllib3 = ">=1.26.0"
|
||||
|
||||
[[package]]
|
||||
name = "types-pytz"
|
||||
version = "2024.1.0.20240417"
|
||||
description = "Typing stubs for pytz"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "types-pytz-2024.1.0.20240417.tar.gz", hash = "sha256:6810c8a1f68f21fdf0f4f374a432487c77645a0ac0b31de4bf4690cf21ad3981"},
|
||||
{file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-requests"
|
||||
version = "2.32.0.20240907"
|
||||
@ -5390,4 +5416,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "a7eb01cef3b4b5ce9d847969349351f2f75df8f3afab0db6224ce1c96e77f5ab"
|
||||
content-hash = "65b40f3fadf237f82bc024896fc3c7f23f4a5640a02e70e20db60b208f3de575"
|
||||
|
@ -51,6 +51,7 @@ pytest-xdist = "^3.3.1"
|
||||
types-requests = "^2.31.0.2"
|
||||
flake8-pyproject = "^1.2.3"
|
||||
pylint = "^2.17.5"
|
||||
pandas-stubs = "^2.2.2.240909"
|
||||
|
||||
|
||||
[tool.poetry.group.examples.dependencies]
|
||||
@ -76,6 +77,14 @@ pretty = true
|
||||
no_implicit_optional = true
|
||||
python_version = "3.10"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"docling_parse.*",
|
||||
"pypdfium2.*",
|
||||
"networkx.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = ["E203", "E501"]
|
||||
|
@ -45,6 +45,8 @@ def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page]):
|
||||
|
||||
|
||||
def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):
|
||||
assert doc_true.main_text is not None, "doc_true cannot be None"
|
||||
assert doc_pred.main_text is not None, "doc_true cannot be None"
|
||||
|
||||
assert len(doc_true.main_text) == len(
|
||||
doc_pred.main_text
|
||||
@ -68,6 +70,13 @@ def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):
|
||||
|
||||
|
||||
def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
|
||||
if doc_true.tables is None:
|
||||
# No tables to check
|
||||
assert doc_pred.tables is None, "not expecting any table on this document"
|
||||
return True
|
||||
|
||||
assert doc_pred.tables is not None, "no tables predicted, but expected in doc_true"
|
||||
|
||||
assert len(doc_true.tables) == len(
|
||||
doc_pred.tables
|
||||
), "document has different count of tables than expected."
|
||||
@ -82,6 +91,8 @@ def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
|
||||
true_item.num_cols == pred_item.num_cols
|
||||
), "table does not have the same #-cols"
|
||||
|
||||
assert true_item.data is not None, "documents are expected to have table data"
|
||||
assert pred_item.data is not None, "documents are expected to have table data"
|
||||
for i, row in enumerate(true_item.data):
|
||||
for j, col in enumerate(true_item.data[i]):
|
||||
|
||||
@ -135,7 +146,7 @@ def verify_conversion_result(
|
||||
doc_true_pages = PageList.validate_json(fr.read())
|
||||
|
||||
with open(json_path, "r") as fr:
|
||||
doc_true = DsDocument.model_validate_json(fr.read())
|
||||
doc_true: DsDocument = DsDocument.model_validate_json(fr.read())
|
||||
|
||||
with open(md_path, "r") as fr:
|
||||
doc_true_md = fr.read()
|
||||
|
Loading…
Reference in New Issue
Block a user