mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-09 13:18:24 +00:00
feat: Add adaptive OCR, factor out treatment of OCR areas and cell filtering (#38)
* Introduce adaptive OCR Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Factor out BaseOcrModel, add docling-parse backend tests, fixes * Make easyocr default dep Signed-off-by: Christoph Auer <cau@zurich.ibm.com> --------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -18,6 +18,10 @@ class PdfPageBackend(ABC):
|
||||
def get_text_cells(self) -> Iterable["Cell"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_bitmap_rects(self, scale: int = 1) -> Iterable["BoundingBox"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_page_image(
|
||||
self, scale: int = 1, cropbox: Optional["BoundingBox"] = None
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import time
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
import pypdfium2 as pdfium
|
||||
from docling_parse.docling_parse import pdf_parser
|
||||
@@ -43,7 +43,7 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
r=x1 * scale * page_size.width / parser_width,
|
||||
t=y1 * scale * page_size.height / parser_height,
|
||||
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||
).to_top_left_origin(page_size.height * scale)
|
||||
).to_top_left_origin(page_height=page_size.height * scale)
|
||||
|
||||
overlap_frac = cell_bbox.intersection_area_with(bbox) / cell_bbox.area()
|
||||
|
||||
@@ -66,6 +66,12 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
for i in range(len(self._dpage["cells"])):
|
||||
rect = self._dpage["cells"][i]["box"]["device"]
|
||||
x0, y0, x1, y1 = rect
|
||||
|
||||
if x1 < x0:
|
||||
x0, x1 = x1, x0
|
||||
if y1 < y0:
|
||||
y0, y1 = y1, y0
|
||||
|
||||
text_piece = self._dpage["cells"][i]["content"]["rnormalized"]
|
||||
cells.append(
|
||||
Cell(
|
||||
@@ -108,6 +114,20 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
|
||||
return cells
|
||||
|
||||
def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
|
||||
AREA_THRESHOLD = 32 * 32
|
||||
|
||||
for i in range(len(self._dpage["images"])):
|
||||
bitmap = self._dpage["images"][i]
|
||||
cropbox = BoundingBox.from_tuple(
|
||||
bitmap["box"], origin=CoordOrigin.BOTTOMLEFT
|
||||
).to_top_left_origin(self.get_size().height)
|
||||
|
||||
if cropbox.area() > AREA_THRESHOLD:
|
||||
cropbox = cropbox.scaled(scale=scale)
|
||||
|
||||
yield cropbox
|
||||
|
||||
def get_page_image(
|
||||
self, scale: int = 1, cropbox: Optional[BoundingBox] = None
|
||||
) -> Image.Image:
|
||||
@@ -173,7 +193,7 @@ class DoclingParseDocumentBackend(PdfDocumentBackend):
|
||||
def page_count(self) -> int:
|
||||
return len(self._parser_doc["pages"])
|
||||
|
||||
def load_page(self, page_no: int) -> PdfPage:
|
||||
def load_page(self, page_no: int) -> DoclingParsePageBackend:
|
||||
return DoclingParsePageBackend(
|
||||
self._pdoc[page_no], self._parser_doc["pages"][page_no]
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Union
|
||||
|
||||
import pypdfium2 as pdfium
|
||||
import pypdfium2.raw as pdfium_c
|
||||
from PIL import Image, ImageDraw
|
||||
from pypdfium2 import PdfPage
|
||||
|
||||
@@ -17,6 +18,19 @@ class PyPdfiumPageBackend(PdfPageBackend):
|
||||
self._ppage = page_obj
|
||||
self.text_page = None
|
||||
|
||||
def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
|
||||
AREA_THRESHOLD = 32 * 32
|
||||
for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]):
|
||||
pos = obj.get_pos()
|
||||
cropbox = BoundingBox.from_tuple(
|
||||
pos, origin=CoordOrigin.BOTTOMLEFT
|
||||
).to_top_left_origin(page_height=self.get_size().height)
|
||||
|
||||
if cropbox.area() > AREA_THRESHOLD:
|
||||
cropbox = cropbox.scaled(scale=scale)
|
||||
|
||||
yield cropbox
|
||||
|
||||
def get_text_in_rect(self, bbox: BoundingBox) -> str:
|
||||
if not self.text_page:
|
||||
self.text_page = self._ppage.get_textpage()
|
||||
@@ -208,7 +222,7 @@ class PyPdfiumDocumentBackend(PdfDocumentBackend):
|
||||
def page_count(self) -> int:
|
||||
return len(self._pdoc)
|
||||
|
||||
def load_page(self, page_no: int) -> PdfPage:
|
||||
def load_page(self, page_no: int) -> PyPdfiumPageBackend:
|
||||
return PyPdfiumPageBackend(self._pdoc[page_no])
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
|
||||
@@ -68,13 +68,21 @@ class BoundingBox(BaseModel):
|
||||
@classmethod
|
||||
def from_tuple(cls, coord: Tuple[float], origin: CoordOrigin):
|
||||
if origin == CoordOrigin.TOPLEFT:
|
||||
return BoundingBox(
|
||||
l=coord[0], t=coord[1], r=coord[2], b=coord[3], coord_origin=origin
|
||||
)
|
||||
l, t, r, b = coord[0], coord[1], coord[2], coord[3]
|
||||
if r < l:
|
||||
l, r = r, l
|
||||
if b < t:
|
||||
b, t = t, b
|
||||
|
||||
return BoundingBox(l=l, t=t, r=r, b=b, coord_origin=origin)
|
||||
elif origin == CoordOrigin.BOTTOMLEFT:
|
||||
return BoundingBox(
|
||||
l=coord[0], b=coord[1], r=coord[2], t=coord[3], coord_origin=origin
|
||||
)
|
||||
l, b, r, t = coord[0], coord[1], coord[2], coord[3]
|
||||
if r < l:
|
||||
l, r = r, l
|
||||
if b > t:
|
||||
b, t = t, b
|
||||
|
||||
return BoundingBox(l=l, t=t, r=r, b=b, coord_origin=origin)
|
||||
|
||||
def area(self) -> float:
|
||||
return (self.r - self.l) * (self.b - self.t)
|
||||
@@ -280,7 +288,7 @@ class TableStructureOptions(BaseModel):
|
||||
|
||||
class PipelineOptions(BaseModel):
|
||||
do_table_structure: bool = True # True: perform table structure extraction
|
||||
do_ocr: bool = False # True: perform OCR, replace programmatic PDF text
|
||||
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
|
||||
|
||||
table_structure_options: TableStructureOptions = TableStructureOptions()
|
||||
|
||||
|
||||
@@ -35,8 +35,6 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentConverter:
|
||||
_layout_model_path = "model_artifacts/layout/beehive_v0.0.5"
|
||||
_table_model_path = "model_artifacts/tableformer"
|
||||
_default_download_filename = "file.pdf"
|
||||
|
||||
def __init__(
|
||||
|
||||
124
docling/models/base_ocr_model.py
Normal file
124
docling/models/base_ocr_model.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import copy
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from rtree import index
|
||||
from scipy.ndimage import find_objects, label
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseOcrModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.enabled = config["enabled"]
|
||||
|
||||
# Computes the optimum amount and coordinates of rectangles to OCR on a given page
|
||||
def get_ocr_rects(self, page: Page) -> Tuple[bool, List[BoundingBox]]:
|
||||
BITMAP_COVERAGE_TRESHOLD = 0.75
|
||||
|
||||
def find_ocr_rects(size, bitmap_rects):
|
||||
image = Image.new(
|
||||
"1", (round(size.width), round(size.height))
|
||||
) # '1' mode is binary
|
||||
|
||||
# Draw all bitmap rects into a binary image
|
||||
draw = ImageDraw.Draw(image)
|
||||
for rect in bitmap_rects:
|
||||
x0, y0, x1, y1 = rect.as_tuple()
|
||||
x0, y0, x1, y1 = round(x0), round(y0), round(x1), round(y1)
|
||||
draw.rectangle([(x0, y0), (x1, y1)], fill=1)
|
||||
|
||||
np_image = np.array(image)
|
||||
|
||||
# Find the connected components
|
||||
labeled_image, num_features = label(
|
||||
np_image > 0
|
||||
) # Label black (0 value) regions
|
||||
|
||||
# Find enclosing bounding boxes for each connected component.
|
||||
slices = find_objects(labeled_image)
|
||||
bounding_boxes = [
|
||||
BoundingBox(
|
||||
l=slc[1].start,
|
||||
t=slc[0].start,
|
||||
r=slc[1].stop - 1,
|
||||
b=slc[0].stop - 1,
|
||||
coord_origin=CoordOrigin.TOPLEFT,
|
||||
)
|
||||
for slc in slices
|
||||
]
|
||||
|
||||
# Compute area fraction on page covered by bitmaps
|
||||
area_frac = np.sum(np_image > 0) / (size.width * size.height)
|
||||
|
||||
return (area_frac, bounding_boxes) # fraction covered # boxes
|
||||
|
||||
bitmap_rects = page._backend.get_bitmap_rects()
|
||||
coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)
|
||||
|
||||
# return full-page rectangle if sufficiently covered with bitmaps
|
||||
if coverage > BITMAP_COVERAGE_TRESHOLD:
|
||||
return [
|
||||
BoundingBox(
|
||||
l=0,
|
||||
t=0,
|
||||
r=page.size.width,
|
||||
b=page.size.height,
|
||||
coord_origin=CoordOrigin.TOPLEFT,
|
||||
)
|
||||
]
|
||||
# return individual rectangles if the bitmap coverage is smaller
|
||||
elif coverage < BITMAP_COVERAGE_TRESHOLD:
|
||||
return ocr_rects
|
||||
|
||||
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
|
||||
def filter_ocr_cells(self, ocr_cells, programmatic_cells):
|
||||
# Create R-tree index for programmatic cells
|
||||
p = index.Property()
|
||||
p.dimension = 2
|
||||
idx = index.Index(properties=p)
|
||||
for i, cell in enumerate(programmatic_cells):
|
||||
idx.insert(i, cell.bbox.as_tuple())
|
||||
|
||||
def is_overlapping_with_existing_cells(ocr_cell):
|
||||
# Query the R-tree to get overlapping rectangles
|
||||
possible_matches_index = list(idx.intersection(ocr_cell.bbox.as_tuple()))
|
||||
|
||||
return (
|
||||
len(possible_matches_index) > 0
|
||||
) # this is a weak criterion but it works.
|
||||
|
||||
filtered_ocr_cells = [
|
||||
rect for rect in ocr_cells if not is_overlapping_with_existing_cells(rect)
|
||||
]
|
||||
return filtered_ocr_cells
|
||||
|
||||
def draw_ocr_rects_and_cells(self, page, ocr_rects):
|
||||
image = copy.deepcopy(page.image)
|
||||
draw = ImageDraw.Draw(image, "RGBA")
|
||||
|
||||
# Draw OCR rectangles as yellow filled rect
|
||||
for rect in ocr_rects:
|
||||
x0, y0, x1, y1 = rect.as_tuple()
|
||||
shade_color = (255, 255, 0, 40) # transparent yellow
|
||||
draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)
|
||||
|
||||
# Draw OCR and programmatic cells
|
||||
for tc in page.cells:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
color = "red"
|
||||
if isinstance(tc, OcrCell):
|
||||
color = "magenta"
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline=color)
|
||||
image.show()
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
pass
|
||||
@@ -1,20 +1,18 @@
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyOcrModel:
|
||||
class EasyOcrModel(BaseOcrModel):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.enabled = config["enabled"]
|
||||
super().__init__(config)
|
||||
|
||||
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
|
||||
|
||||
if self.enabled:
|
||||
@@ -29,49 +27,44 @@ class EasyOcrModel:
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
# rects = page._fpage.
|
||||
high_res_image = page.get_image(scale=self.scale)
|
||||
im = numpy.array(high_res_image)
|
||||
result = self.reader.readtext(im)
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
del high_res_image
|
||||
del im
|
||||
|
||||
cells = [
|
||||
OcrCell(
|
||||
id=ix,
|
||||
text=line[1],
|
||||
confidence=line[2],
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=(
|
||||
line[0][0][0] / self.scale,
|
||||
line[0][0][1] / self.scale,
|
||||
line[0][2][0] / self.scale,
|
||||
line[0][2][1] / self.scale,
|
||||
),
|
||||
origin=CoordOrigin.TOPLEFT,
|
||||
),
|
||||
all_ocr_cells = []
|
||||
for ocr_rect in ocr_rects:
|
||||
high_res_image = page._backend.get_page_image(
|
||||
scale=self.scale, cropbox=ocr_rect
|
||||
)
|
||||
for ix, line in enumerate(result)
|
||||
]
|
||||
im = numpy.array(high_res_image)
|
||||
result = self.reader.readtext(im)
|
||||
|
||||
page.cells = cells # For now, just overwrites all digital cells.
|
||||
del high_res_image
|
||||
del im
|
||||
|
||||
cells = [
|
||||
OcrCell(
|
||||
id=ix,
|
||||
text=line[1],
|
||||
confidence=line[2],
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=(
|
||||
(line[0][0][0] / self.scale) + ocr_rect.l,
|
||||
(line[0][0][1] / self.scale) + ocr_rect.t,
|
||||
(line[0][2][0] / self.scale) + ocr_rect.l,
|
||||
(line[0][2][1] / self.scale) + ocr_rect.t,
|
||||
),
|
||||
origin=CoordOrigin.TOPLEFT,
|
||||
),
|
||||
)
|
||||
for ix, line in enumerate(result)
|
||||
]
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)
|
||||
|
||||
page.cells.extend(filtered_ocr_cells)
|
||||
|
||||
# DEBUG code:
|
||||
def draw_clusters_and_cells():
|
||||
image = copy.deepcopy(page.image)
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
cell_color = (
|
||||
random.randint(30, 140),
|
||||
random.randint(30, 140),
|
||||
random.randint(30, 140),
|
||||
)
|
||||
for tc in cells:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
|
||||
image.show()
|
||||
|
||||
# draw_clusters_and_cells()
|
||||
# self.draw_ocr_rects_and_cells(page, ocr_rects)
|
||||
|
||||
yield page
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import random
|
||||
from typing import Iterable, List
|
||||
|
||||
import numpy
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from docling.datamodel.base_models import Page, PipelineOptions
|
||||
from docling.datamodel.base_models import PipelineOptions
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.pipeline.base_model_pipeline import BaseModelPipeline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user