mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-16 16:48:21 +00:00
ci: add coverage and ruff (#1383)
* add coverage calculation and push Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * new codecov version and usage of token Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * enable ruff formatter instead of black and isort Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * apply ruff lint fixes Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * apply ruff unsafe fixes Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add removed imports Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * runs 1 on linter issues Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * finalize linter fixes Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * Update pyproject.toml Co-authored-by: Cesar Berrospi Ramis <75900930+ceberam@users.noreply.github.com> Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> --------- Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Co-authored-by: Cesar Berrospi Ramis <75900930+ceberam@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Iterable
|
||||
from collections.abc import Iterable
|
||||
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
@@ -10,7 +10,6 @@ from docling.utils.profiling import TimeRecorder
|
||||
|
||||
|
||||
class ApiVlmModel(BasePageModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, Iterable, Optional, Protocol, Type
|
||||
from collections.abc import Iterable
|
||||
from typing import Generic, Optional, Protocol, Type
|
||||
|
||||
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
|
||||
from typing_extensions import TypeVar
|
||||
@@ -29,7 +30,6 @@ EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
||||
|
||||
|
||||
class GenericEnrichmentModel(ABC, Generic[EnrichElementT]):
|
||||
|
||||
elements_batch_size: int = settings.perf.elements_batch_size
|
||||
|
||||
@abstractmethod
|
||||
@@ -50,7 +50,6 @@ class GenericEnrichmentModel(ABC, Generic[EnrichElementT]):
|
||||
|
||||
|
||||
class BaseEnrichmentModel(GenericEnrichmentModel[NodeItem]):
|
||||
|
||||
def prepare_element(
|
||||
self, conv_res: ConversionResult, element: NodeItem
|
||||
) -> Optional[NodeItem]:
|
||||
@@ -62,7 +61,6 @@ class BaseEnrichmentModel(GenericEnrichmentModel[NodeItem]):
|
||||
class BaseItemAndImageEnrichmentModel(
|
||||
GenericEnrichmentModel[ItemAndImageEnrichmentElement]
|
||||
):
|
||||
|
||||
images_scale: float
|
||||
expansion_factor: float = 0.0
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import copy
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
from docling_core.types.doc.page import BoundingRectangle, PdfTextCell, TextCell
|
||||
from PIL import Image, ImageDraw
|
||||
from rtree import index
|
||||
from scipy.ndimage import binary_dilation, find_objects, label
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import re
|
||||
from collections import Counter
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import (
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import (
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
@@ -58,12 +59,10 @@ class EasyOcrModel(BaseOcrModel):
|
||||
device = decide_device(accelerator_options.device)
|
||||
# Enable easyocr GPU if running on CUDA, MPS
|
||||
use_gpu = any(
|
||||
[
|
||||
device.startswith(x)
|
||||
for x in [
|
||||
AcceleratorDevice.CUDA.value,
|
||||
AcceleratorDevice.MPS.value,
|
||||
]
|
||||
device.startswith(x)
|
||||
for x in [
|
||||
AcceleratorDevice.CUDA.value,
|
||||
AcceleratorDevice.MPS.value,
|
||||
]
|
||||
)
|
||||
else:
|
||||
@@ -98,8 +97,10 @@ class EasyOcrModel(BaseOcrModel):
|
||||
progress: bool = False,
|
||||
) -> Path:
|
||||
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
|
||||
from easyocr.config import detection_models as det_models_dict
|
||||
from easyocr.config import recognition_models as rec_models_dict
|
||||
from easyocr.config import (
|
||||
detection_models as det_models_dict,
|
||||
recognition_models as rec_models_dict,
|
||||
)
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
|
||||
@@ -126,13 +127,11 @@ class EasyOcrModel(BaseOcrModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
if not self.enabled:
|
||||
yield from page_batch
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
|
||||
@@ -9,7 +9,7 @@ from docling.models.factories.picture_description_factory import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def get_ocr_factory(allow_external_plugins: bool = False) -> OcrFactory:
|
||||
factory = OcrFactory()
|
||||
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
|
||||
@@ -17,7 +17,7 @@ def get_ocr_factory(allow_external_plugins: bool = False) -> OcrFactory:
|
||||
return factory
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def get_picture_description_factory(
|
||||
allow_external_plugins: bool = False,
|
||||
) -> PictureDescriptionFactory:
|
||||
|
||||
@@ -33,7 +33,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
|
||||
@property
|
||||
def registered_kind(self) -> list[str]:
|
||||
return list(opt.kind for opt in self._classes.keys())
|
||||
return [opt.kind for opt in self._classes.keys()]
|
||||
|
||||
def get_enum(self) -> enum.Enum:
|
||||
return enum.Enum(
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
HuggingFaceVlmOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceMlxModel(BasePageModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@@ -32,7 +29,6 @@ class HuggingFaceMlxModel(BasePageModel):
|
||||
self.vlm_options = vlm_options
|
||||
|
||||
if self.enabled:
|
||||
|
||||
try:
|
||||
from mlx_vlm import generate, load # type: ignore
|
||||
from mlx_vlm.prompt_utils import apply_chat_template # type: ignore
|
||||
@@ -125,6 +121,8 @@ class HuggingFaceMlxModel(BasePageModel):
|
||||
generation_time = time.time() - start_time
|
||||
page_tags = output
|
||||
|
||||
_log.debug(f"Generation time {generation_time:.2f} seconds.")
|
||||
|
||||
# inference_time = time.time() - start_time
|
||||
# tokens_per_second = num_tokens / generation_time
|
||||
# print("")
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
HuggingFaceVlmOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
@@ -19,7 +18,6 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceVlmModel(BasePageModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
@@ -42,7 +40,7 @@ class HuggingFaceVlmModel(BasePageModel):
|
||||
device = decide_device(accelerator_options.device)
|
||||
self.device = device
|
||||
|
||||
_log.debug("Available device for HuggingFace VLM: {}".format(device))
|
||||
_log.debug(f"Available device for HuggingFace VLM: {device}")
|
||||
|
||||
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
|
||||
|
||||
@@ -168,6 +166,10 @@ class HuggingFaceVlmModel(BasePageModel):
|
||||
num_tokens = len(generated_ids[0])
|
||||
page_tags = generated_texts
|
||||
|
||||
_log.debug(
|
||||
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
|
||||
)
|
||||
|
||||
# inference_time = time.time() - start_time
|
||||
# tokens_per_second = num_tokens / generation_time
|
||||
# print("")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from docling_core.types.doc import DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
@@ -142,7 +143,6 @@ class LayoutModel(BasePageModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Tuple, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
from docling_core.types.doc.page import BoundingRectangle, TextCell
|
||||
@@ -41,7 +42,7 @@ class OcrMacModel(BaseOcrModel):
|
||||
|
||||
if self.enabled:
|
||||
if "darwin" != sys.platform:
|
||||
raise RuntimeError(f"OcrMac is only supported on Mac.")
|
||||
raise RuntimeError("OcrMac is only supported on Mac.")
|
||||
install_errmsg = (
|
||||
"ocrmac is not correctly installed. "
|
||||
"Please install it via `pip install ocrmac` to use this OCR engine. "
|
||||
@@ -58,7 +59,6 @@ class OcrMacModel(BaseOcrModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
if not self.enabled:
|
||||
yield from page_batch
|
||||
return
|
||||
@@ -69,7 +69,6 @@ class OcrMacModel(BaseOcrModel):
|
||||
yield page
|
||||
else:
|
||||
with TimeRecorder(conv_res, "ocr"):
|
||||
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
all_ocr_cells = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterable, List
|
||||
from collections.abc import Iterable
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -53,9 +54,9 @@ class PageAssembleModel(BasePageModel):
|
||||
sanitized_text = "".join(lines)
|
||||
|
||||
# Text normalization
|
||||
sanitized_text = sanitized_text.replace("⁄", "/")
|
||||
sanitized_text = sanitized_text.replace("’", "'")
|
||||
sanitized_text = sanitized_text.replace("‘", "'")
|
||||
sanitized_text = sanitized_text.replace("⁄", "/") # noqa: RUF001
|
||||
sanitized_text = sanitized_text.replace("’", "'") # noqa: RUF001
|
||||
sanitized_text = sanitized_text.replace("‘", "'") # noqa: RUF001
|
||||
sanitized_text = sanitized_text.replace("“", '"')
|
||||
sanitized_text = sanitized_text.replace("”", '"')
|
||||
sanitized_text = sanitized_text.replace("•", "·")
|
||||
@@ -71,7 +72,6 @@ class PageAssembleModel(BasePageModel):
|
||||
yield page
|
||||
else:
|
||||
with TimeRecorder(conv_res, "page_assemble"):
|
||||
|
||||
assert page.predictions.layout is not None
|
||||
|
||||
# assembles some JSON output page by page.
|
||||
@@ -83,7 +83,6 @@ class PageAssembleModel(BasePageModel):
|
||||
for cluster in page.predictions.layout.clusters:
|
||||
# _log.info("Cluster label seen:", cluster.label)
|
||||
if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
|
||||
|
||||
textlines = [
|
||||
cell.text.replace("\x02", "-").strip()
|
||||
for cell in cluster.cells
|
||||
@@ -109,9 +108,7 @@ class PageAssembleModel(BasePageModel):
|
||||
tbl = page.predictions.tablestructure.table_map.get(
|
||||
cluster.id, None
|
||||
)
|
||||
if (
|
||||
not tbl
|
||||
): # fallback: add table without structure, if it isn't present
|
||||
if not tbl: # fallback: add table without structure, if it isn't present
|
||||
tbl = Table(
|
||||
label=cluster.label,
|
||||
id=cluster.id,
|
||||
@@ -130,9 +127,7 @@ class PageAssembleModel(BasePageModel):
|
||||
fig = page.predictions.figures_classification.figure_map.get(
|
||||
cluster.id, None
|
||||
)
|
||||
if (
|
||||
not fig
|
||||
): # fallback: add figure without classification, if it isn't present
|
||||
if not fig: # fallback: add figure without classification, if it isn't present
|
||||
fig = FigureElement(
|
||||
label=cluster.label,
|
||||
id=cluster.id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import ImageDraw
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Type, Union
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, Optional, Type, Union
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from docling_core.types.doc import (
|
||||
DoclingDocument,
|
||||
NodeItem,
|
||||
PictureClassificationClass,
|
||||
PictureItem,
|
||||
)
|
||||
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Type, Union
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -13,7 +14,6 @@ from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
|
||||
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
||||
|
||||
@classmethod
|
||||
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
|
||||
return PictureDescriptionVlmOptions
|
||||
@@ -36,7 +36,6 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
||||
self.options: PictureDescriptionVlmOptions
|
||||
|
||||
if self.enabled:
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models(repo_id=self.options.repo_id)
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
@@ -74,13 +75,11 @@ class RapidOcrModel(BaseOcrModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
if not self.enabled:
|
||||
yield from page_batch
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
import copy
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from docling_core.types.doc import (
|
||||
BoundingBox,
|
||||
CoordOrigin,
|
||||
DocItem,
|
||||
DocItemLabel,
|
||||
DoclingDocument,
|
||||
DocumentOrigin,
|
||||
@@ -17,13 +12,10 @@ from docling_core.types.doc import (
|
||||
TableData,
|
||||
)
|
||||
from docling_core.types.doc.document import ContentLayer
|
||||
from docling_core.types.legacy_doc.base import Ref
|
||||
from docling_core.types.legacy_doc.document import BaseText
|
||||
from docling_ibm_models.reading_order.reading_order_rb import (
|
||||
PageElement as ReadingOrderPageElement,
|
||||
ReadingOrderPredictor,
|
||||
)
|
||||
from docling_ibm_models.reading_order.reading_order_rb import ReadingOrderPredictor
|
||||
from PIL import ImageDraw
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
@@ -35,7 +27,6 @@ from docling.datamodel.base_models import (
|
||||
TextElement,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
|
||||
@@ -53,12 +44,10 @@ class ReadingOrderModel:
|
||||
def _assembled_to_readingorder_elements(
|
||||
self, conv_res: ConversionResult
|
||||
) -> List[ReadingOrderPageElement]:
|
||||
|
||||
elements: List[ReadingOrderPageElement] = []
|
||||
page_no_to_pages = {p.page_no: p for p in conv_res.pages}
|
||||
|
||||
for element in conv_res.assembled.elements:
|
||||
|
||||
page_height = page_no_to_pages[element.page_no].size.height # type: ignore
|
||||
bbox = element.cluster.bbox.to_bottom_left_origin(page_height)
|
||||
text = element.text or ""
|
||||
@@ -84,7 +73,6 @@ class ReadingOrderModel:
|
||||
def _add_child_elements(
|
||||
self, element: BasePageElement, doc_item: NodeItem, doc: DoclingDocument
|
||||
):
|
||||
|
||||
child: Cluster
|
||||
for child in element.cluster.children:
|
||||
c_label = child.label
|
||||
@@ -110,7 +98,7 @@ class ReadingOrderModel:
|
||||
else:
|
||||
doc.add_text(parent=doc_item, label=c_label, text=c_text, prov=c_prov)
|
||||
|
||||
def _readingorder_elements_to_docling_doc(
|
||||
def _readingorder_elements_to_docling_doc( # noqa: C901
|
||||
self,
|
||||
conv_res: ConversionResult,
|
||||
ro_elements: List[ReadingOrderPageElement],
|
||||
@@ -118,7 +106,6 @@ class ReadingOrderModel:
|
||||
el_to_footnotes_mapping: Dict[int, List[int]],
|
||||
el_merges_mapping: Dict[int, List[int]],
|
||||
) -> DoclingDocument:
|
||||
|
||||
id_to_elem = {
|
||||
RefItem(cref=f"#/{elem.page_no}/{elem.cluster.id}").cref: elem
|
||||
for elem in conv_res.assembled.elements
|
||||
@@ -192,7 +179,6 @@ class ReadingOrderModel:
|
||||
|
||||
code_item.footnotes.append(new_footnote_item.get_ref())
|
||||
else:
|
||||
|
||||
new_item, current_list = self._handle_text_element(
|
||||
element, out_doc, current_list, page_height
|
||||
)
|
||||
@@ -206,7 +192,6 @@ class ReadingOrderModel:
|
||||
)
|
||||
|
||||
elif isinstance(element, Table):
|
||||
|
||||
tbl_data = TableData(
|
||||
num_rows=element.num_rows,
|
||||
num_cols=element.num_cols,
|
||||
@@ -342,12 +327,12 @@ class ReadingOrderModel:
|
||||
return new_item, current_list
|
||||
|
||||
def _merge_elements(self, element, merged_elem, new_item, page_height):
|
||||
assert isinstance(
|
||||
merged_elem, type(element)
|
||||
), "Merged element must be of same type as element."
|
||||
assert (
|
||||
merged_elem.label == new_item.label
|
||||
), "Labels of merged elements must match."
|
||||
assert isinstance(merged_elem, type(element)), (
|
||||
"Merged element must be of same type as element."
|
||||
)
|
||||
assert merged_elem.label == new_item.label, (
|
||||
"Labels of merged elements must match."
|
||||
)
|
||||
prov = ProvenanceItem(
|
||||
page_no=element.page_no + 1,
|
||||
charspan=(
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import copy
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
|
||||
from docling_core.types.doc.page import (
|
||||
BoundingRectangle,
|
||||
SegmentedPdfPage,
|
||||
TextCellUnit,
|
||||
)
|
||||
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
||||
@@ -44,7 +44,6 @@ class TableStructureModel(BasePageModel):
|
||||
|
||||
self.enabled = enabled
|
||||
if self.enabled:
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models() / self._model_path
|
||||
else:
|
||||
@@ -175,7 +174,6 @@ class TableStructureModel(BasePageModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
if not self.enabled:
|
||||
yield from page_batch
|
||||
return
|
||||
@@ -186,7 +184,6 @@ class TableStructureModel(BasePageModel):
|
||||
yield page
|
||||
else:
|
||||
with TimeRecorder(conv_res, "table_structure"):
|
||||
|
||||
assert page.predictions.layout is not None
|
||||
assert page.size is not None
|
||||
|
||||
@@ -260,7 +257,6 @@ class TableStructureModel(BasePageModel):
|
||||
table_out = tf_output[0]
|
||||
table_cells = []
|
||||
for element in table_out["tf_responses"]:
|
||||
|
||||
if not self.do_cell_matching:
|
||||
the_bbox = BoundingBox.model_validate(
|
||||
element["bbox"]
|
||||
|
||||
@@ -3,9 +3,10 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from subprocess import DEVNULL, PIPE, Popen
|
||||
from typing import Iterable, List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pandas as pd
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
@@ -63,8 +64,7 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
)
|
||||
|
||||
def _get_name_and_version(self) -> Tuple[str, str]:
|
||||
|
||||
if self._name != None and self._version != None:
|
||||
if self._name is not None and self._version is not None:
|
||||
return self._name, self._version # type: ignore
|
||||
|
||||
cmd = [self.options.tesseract_cmd, "--version"]
|
||||
@@ -125,14 +125,16 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
# _log.info(decoded_data)
|
||||
|
||||
# Read the TSV file generated by Tesseract
|
||||
df = pd.read_csv(io.StringIO(decoded_data), quoting=csv.QUOTE_NONE, sep="\t")
|
||||
df_result = pd.read_csv(
|
||||
io.StringIO(decoded_data), quoting=csv.QUOTE_NONE, sep="\t"
|
||||
)
|
||||
|
||||
# Display the dataframe (optional)
|
||||
# _log.info("df: ", df.head())
|
||||
|
||||
# Filter rows that contain actual text (ignore header or empty rows)
|
||||
df_filtered = df[
|
||||
df["text"].notnull() & (df["text"].apply(str).str.strip() != "")
|
||||
df_filtered = df_result[
|
||||
df_result["text"].notna() & (df_result["text"].apply(str).str.strip() != "")
|
||||
]
|
||||
|
||||
return df_filtered
|
||||
@@ -149,10 +151,10 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
proc = Popen(cmd, stdout=PIPE, stderr=DEVNULL)
|
||||
output, _ = proc.communicate()
|
||||
decoded_data = output.decode("utf-8")
|
||||
df = pd.read_csv(
|
||||
df_detected = pd.read_csv(
|
||||
io.StringIO(decoded_data), sep=":", header=None, names=["key", "value"]
|
||||
)
|
||||
scripts = df.loc[df["key"] == "Script"].value.tolist()
|
||||
scripts = df_detected.loc[df_detected["key"] == "Script"].value.tolist()
|
||||
if len(scripts) == 0:
|
||||
_log.warning("Tesseract cannot detect the script of the page")
|
||||
return None
|
||||
@@ -183,11 +185,11 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
proc = Popen(cmd, stdout=PIPE, stderr=DEVNULL)
|
||||
output, _ = proc.communicate()
|
||||
decoded_data = output.decode("utf-8")
|
||||
df = pd.read_csv(io.StringIO(decoded_data), header=None)
|
||||
self._tesseract_languages = df[0].tolist()[1:]
|
||||
df_list = pd.read_csv(io.StringIO(decoded_data), header=None)
|
||||
self._tesseract_languages = df_list[0].tolist()[1:]
|
||||
|
||||
# Decide the script prefix
|
||||
if any([l.startswith("script/") for l in self._tesseract_languages]):
|
||||
if any(lang.startswith("script/") for lang in self._tesseract_languages):
|
||||
script_prefix = "script/"
|
||||
else:
|
||||
script_prefix = ""
|
||||
@@ -197,7 +199,6 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
|
||||
if not self.enabled:
|
||||
yield from page_batch
|
||||
return
|
||||
@@ -225,19 +226,19 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
fname = image_file.name
|
||||
high_res_image.save(image_file)
|
||||
|
||||
df = self._run_tesseract(fname)
|
||||
df_result = self._run_tesseract(fname)
|
||||
finally:
|
||||
if os.path.exists(fname):
|
||||
os.remove(fname)
|
||||
|
||||
# _log.info(df)
|
||||
# _log.info(df_result)
|
||||
|
||||
# Print relevant columns (bounding box and text)
|
||||
for ix, row in df.iterrows():
|
||||
for ix, row in df_result.iterrows():
|
||||
text = row["text"]
|
||||
conf = row["conf"]
|
||||
|
||||
l = float(row["left"])
|
||||
l = float(row["left"]) # noqa: E741
|
||||
b = float(row["top"])
|
||||
w = float(row["width"])
|
||||
h = float(row["height"])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
from docling_core.types.doc.page import BoundingRectangle, TextCell
|
||||
@@ -37,9 +38,6 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
self.options: TesseractOcrOptions
|
||||
|
||||
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
|
||||
self.reader = None
|
||||
self.osd_reader = None
|
||||
self.script_readers: dict[str, tesserocr.PyTessBaseAPI] = {}
|
||||
|
||||
if self.enabled:
|
||||
install_errmsg = (
|
||||
@@ -64,7 +62,7 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
raise ImportError(install_errmsg)
|
||||
try:
|
||||
tesseract_version = tesserocr.tesseract_version()
|
||||
except:
|
||||
except Exception:
|
||||
raise ImportError(install_errmsg)
|
||||
|
||||
_, self._tesserocr_languages = tesserocr.get_languages()
|
||||
@@ -75,7 +73,7 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
_log.debug("Initializing TesserOCR: %s", tesseract_version)
|
||||
lang = "+".join(self.options.lang)
|
||||
|
||||
if any([l.startswith("script/") for l in self._tesserocr_languages]):
|
||||
if any(lang.startswith("script/") for lang in self._tesserocr_languages):
|
||||
self.script_prefix = "script/"
|
||||
else:
|
||||
self.script_prefix = ""
|
||||
@@ -86,6 +84,10 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
"oem": tesserocr.OEM.DEFAULT,
|
||||
}
|
||||
|
||||
self.reader = None
|
||||
self.osd_reader = None
|
||||
self.script_readers: dict[str, tesserocr.PyTessBaseAPI] = {}
|
||||
|
||||
if self.options.path is not None:
|
||||
tesserocr_kwargs["path"] = self.options.path
|
||||
|
||||
|
||||
Reference in New Issue
Block a user